1 // Copyright (c) 2016 Google Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include "source/opt/set_spec_constant_default_value_pass.h" 16 17 #include <algorithm> 18 #include <cctype> 19 #include <cstring> 20 #include <tuple> 21 #include <vector> 22 23 #include "source/opt/def_use_manager.h" 24 #include "source/opt/ir_context.h" 25 #include "source/opt/type_manager.h" 26 #include "source/opt/types.h" 27 #include "source/util/make_unique.h" 28 #include "source/util/parse_number.h" 29 #include "spirv-tools/libspirv.h" 30 31 namespace spvtools { 32 namespace opt { 33 34 namespace { 35 using utils::EncodeNumberStatus; 36 using utils::NumberType; 37 using utils::ParseAndEncodeNumber; 38 using utils::ParseNumber; 39 40 // Given a numeric value in a null-terminated c string and the expected type of 41 // the value, parses the string and encodes it in a vector of words. If the 42 // value is a scalar integer or floating point value, encodes the value in 43 // SPIR-V encoding format. If the value is 'false' or 'true', returns a vector 44 // with single word with value 0 or 1 respectively. Returns the vector 45 // containing the encoded value on success. Otherwise returns an empty vector. 46 std::vector<uint32_t> ParseDefaultValueStr(const char* text, 47 const analysis::Type* type) { 48 std::vector<uint32_t> result; 49 if (!strcmp(text, "true") && type->AsBool()) { 50 result.push_back(1u); 51 } else if (!strcmp(text, "false") && type->AsBool()) { 52 result.push_back(0u); 53 } else { 54 NumberType number_type = {32, SPV_NUMBER_UNSIGNED_INT}; 55 if (const auto* IT = type->AsInteger()) { 56 number_type.bitwidth = IT->width(); 57 number_type.kind = 58 IT->IsSigned() ? SPV_NUMBER_SIGNED_INT : SPV_NUMBER_UNSIGNED_INT; 59 } else if (const auto* FT = type->AsFloat()) { 60 number_type.bitwidth = FT->width(); 61 number_type.kind = SPV_NUMBER_FLOATING; 62 } else { 63 // Does not handle types other then boolean, integer or float. Returns 64 // empty vector. 65 result.clear(); 66 return result; 67 } 68 EncodeNumberStatus rc = ParseAndEncodeNumber( 69 text, number_type, [&result](uint32_t word) { result.push_back(word); }, 70 nullptr); 71 // Clear the result vector on failure. 72 if (rc != EncodeNumberStatus::kSuccess) { 73 result.clear(); 74 } 75 } 76 return result; 77 } 78 79 // Given a bit pattern and a type, checks if the bit pattern is compatible 80 // with the type. If so, returns the bit pattern, otherwise returns an empty 81 // bit pattern. If the given bit pattern is empty, returns an empty bit 82 // pattern. If the given type represents a SPIR-V Boolean type, the bit pattern 83 // to be returned is determined with the following standard: 84 // If any words in the input bit pattern are non zero, returns a bit pattern 85 // with 0x1, which represents a 'true'. 86 // If all words in the bit pattern are zero, returns a bit pattern with 0x0, 87 // which represents a 'false'. 88 std::vector<uint32_t> ParseDefaultValueBitPattern( 89 const std::vector<uint32_t>& input_bit_pattern, 90 const analysis::Type* type) { 91 std::vector<uint32_t> result; 92 if (type->AsBool()) { 93 if (std::any_of(input_bit_pattern.begin(), input_bit_pattern.end(), 94 [](uint32_t i) { return i != 0; })) { 95 result.push_back(1u); 96 } else { 97 result.push_back(0u); 98 } 99 return result; 100 } else if (const auto* IT = type->AsInteger()) { 101 if (IT->width() == input_bit_pattern.size() * sizeof(uint32_t) * 8) { 102 return std::vector<uint32_t>(input_bit_pattern); 103 } 104 } else if (const auto* FT = type->AsFloat()) { 105 if (FT->width() == input_bit_pattern.size() * sizeof(uint32_t) * 8) { 106 return std::vector<uint32_t>(input_bit_pattern); 107 } 108 } 109 result.clear(); 110 return result; 111 } 112 113 // Returns true if the given instruction's result id could have a SpecId 114 // decoration. 115 bool CanHaveSpecIdDecoration(const Instruction& inst) { 116 switch (inst.opcode()) { 117 case SpvOp::SpvOpSpecConstant: 118 case SpvOp::SpvOpSpecConstantFalse: 119 case SpvOp::SpvOpSpecConstantTrue: 120 return true; 121 default: 122 return false; 123 } 124 } 125 126 // Given a decoration group defining instruction that is decorated with SpecId 127 // decoration, finds the spec constant defining instruction which is the real 128 // target of the SpecId decoration. Returns the spec constant defining 129 // instruction if such an instruction is found, otherwise returns a nullptr. 130 Instruction* GetSpecIdTargetFromDecorationGroup( 131 const Instruction& decoration_group_defining_inst, 132 analysis::DefUseManager* def_use_mgr) { 133 // Find the OpGroupDecorate instruction which consumes the given decoration 134 // group. Note that the given decoration group has SpecId decoration, which 135 // is unique for different spec constants. So the decoration group cannot be 136 // consumed by different OpGroupDecorate instructions. Therefore we only need 137 // the first OpGroupDecoration instruction that uses the given decoration 138 // group. 139 Instruction* group_decorate_inst = nullptr; 140 if (def_use_mgr->WhileEachUser(&decoration_group_defining_inst, 141 [&group_decorate_inst](Instruction* user) { 142 if (user->opcode() == 143 SpvOp::SpvOpGroupDecorate) { 144 group_decorate_inst = user; 145 return false; 146 } 147 return true; 148 })) 149 return nullptr; 150 151 // Scan through the target ids of the OpGroupDecorate instruction. There 152 // should be only one spec constant target consumes the SpecId decoration. 153 // If multiple target ids are presented in the OpGroupDecorate instruction, 154 // they must be the same one that defined by an eligible spec constant 155 // instruction. If the OpGroupDecorate instruction has different target ids 156 // or a target id is not defined by an eligible spec cosntant instruction, 157 // returns a nullptr. 158 Instruction* target_inst = nullptr; 159 for (uint32_t i = 1; i < group_decorate_inst->NumInOperands(); i++) { 160 // All the operands of a OpGroupDecorate instruction should be of type 161 // SPV_OPERAND_TYPE_ID. 162 uint32_t candidate_id = group_decorate_inst->GetSingleWordInOperand(i); 163 Instruction* candidate_inst = def_use_mgr->GetDef(candidate_id); 164 165 if (!candidate_inst) { 166 continue; 167 } 168 169 if (!target_inst) { 170 // If the spec constant target has not been found yet, check if the 171 // candidate instruction is the target. 172 if (CanHaveSpecIdDecoration(*candidate_inst)) { 173 target_inst = candidate_inst; 174 } else { 175 // Spec id decoration should not be applied on other instructions. 176 // TODO(qining): Emit an error message in the invalid case once the 177 // error handling is done. 178 return nullptr; 179 } 180 } else { 181 // If the spec constant target has been found, check if the candidate 182 // instruction is the same one as the target. The module is invalid if 183 // the candidate instruction is different with the found target. 184 // TODO(qining): Emit an error messaage in the invalid case once the 185 // error handling is done. 186 if (candidate_inst != target_inst) return nullptr; 187 } 188 } 189 return target_inst; 190 } 191 } // namespace 192 193 Pass::Status SetSpecConstantDefaultValuePass::Process() { 194 // The operand index of decoration target in an OpDecorate instruction. 195 const uint32_t kTargetIdOperandIndex = 0; 196 // The operand index of the decoration literal in an OpDecorate instruction. 197 const uint32_t kDecorationOperandIndex = 1; 198 // The operand index of Spec id literal value in an OpDecorate SpecId 199 // instruction. 200 const uint32_t kSpecIdLiteralOperandIndex = 2; 201 // The number of operands in an OpDecorate SpecId instruction. 202 const uint32_t kOpDecorateSpecIdNumOperands = 3; 203 // The in-operand index of the default value in a OpSpecConstant instruction. 204 const uint32_t kOpSpecConstantLiteralInOperandIndex = 0; 205 206 bool modified = false; 207 // Scan through all the annotation instructions to find 'OpDecorate SpecId' 208 // instructions. Then extract the decoration target of those instructions. 209 // The decoration targets should be spec constant defining instructions with 210 // opcode: OpSpecConstant{|True|False}. The spec id of those spec constants 211 // will be used to look up their new default values in the mapping from 212 // spec id to new default value strings. Once a new default value string 213 // is found for a spec id, the string will be parsed according to the target 214 // spec constant type. The parsed value will be used to replace the original 215 // default value of the target spec constant. 216 for (Instruction& inst : context()->annotations()) { 217 // Only process 'OpDecorate SpecId' instructions 218 if (inst.opcode() != SpvOp::SpvOpDecorate) continue; 219 if (inst.NumOperands() != kOpDecorateSpecIdNumOperands) continue; 220 if (inst.GetSingleWordInOperand(kDecorationOperandIndex) != 221 uint32_t(SpvDecoration::SpvDecorationSpecId)) { 222 continue; 223 } 224 225 // 'inst' is an OpDecorate SpecId instruction. 226 uint32_t spec_id = inst.GetSingleWordOperand(kSpecIdLiteralOperandIndex); 227 uint32_t target_id = inst.GetSingleWordOperand(kTargetIdOperandIndex); 228 229 // Find the spec constant defining instruction. Note that the 230 // target_id might be a decoration group id. 231 Instruction* spec_inst = nullptr; 232 if (Instruction* target_inst = get_def_use_mgr()->GetDef(target_id)) { 233 if (target_inst->opcode() == SpvOp::SpvOpDecorationGroup) { 234 spec_inst = 235 GetSpecIdTargetFromDecorationGroup(*target_inst, get_def_use_mgr()); 236 } else { 237 spec_inst = target_inst; 238 } 239 } else { 240 continue; 241 } 242 if (!spec_inst) continue; 243 244 // Get the default value bit pattern for this spec id. 245 std::vector<uint32_t> bit_pattern; 246 247 if (spec_id_to_value_str_.size() != 0) { 248 // Search for the new string-form default value for this spec id. 249 auto iter = spec_id_to_value_str_.find(spec_id); 250 if (iter == spec_id_to_value_str_.end()) { 251 continue; 252 } 253 254 // Gets the string of the default value and parses it to bit pattern 255 // with the type of the spec constant. 256 const std::string& default_value_str = iter->second; 257 bit_pattern = ParseDefaultValueStr( 258 default_value_str.c_str(), 259 context()->get_type_mgr()->GetType(spec_inst->type_id())); 260 261 } else { 262 // Search for the new bit-pattern-form default value for this spec id. 263 auto iter = spec_id_to_value_bit_pattern_.find(spec_id); 264 if (iter == spec_id_to_value_bit_pattern_.end()) { 265 continue; 266 } 267 268 // Gets the bit-pattern of the default value from the map directly. 269 bit_pattern = ParseDefaultValueBitPattern( 270 iter->second, 271 context()->get_type_mgr()->GetType(spec_inst->type_id())); 272 } 273 274 if (bit_pattern.empty()) continue; 275 276 // Update the operand bit patterns of the spec constant defining 277 // instruction. 278 switch (spec_inst->opcode()) { 279 case SpvOp::SpvOpSpecConstant: 280 // If the new value is the same with the original value, no 281 // need to do anything. Otherwise update the operand words. 282 if (spec_inst->GetInOperand(kOpSpecConstantLiteralInOperandIndex) 283 .words != bit_pattern) { 284 spec_inst->SetInOperand(kOpSpecConstantLiteralInOperandIndex, 285 std::move(bit_pattern)); 286 modified = true; 287 } 288 break; 289 case SpvOp::SpvOpSpecConstantTrue: 290 // If the new value is also 'true', no need to change anything. 291 // Otherwise, set the opcode to OpSpecConstantFalse; 292 if (!static_cast<bool>(bit_pattern.front())) { 293 spec_inst->SetOpcode(SpvOp::SpvOpSpecConstantFalse); 294 modified = true; 295 } 296 break; 297 case SpvOp::SpvOpSpecConstantFalse: 298 // If the new value is also 'false', no need to change anything. 299 // Otherwise, set the opcode to OpSpecConstantTrue; 300 if (static_cast<bool>(bit_pattern.front())) { 301 spec_inst->SetOpcode(SpvOp::SpvOpSpecConstantTrue); 302 modified = true; 303 } 304 break; 305 default: 306 break; 307 } 308 // No need to update the DefUse manager, as this pass does not change any 309 // ids. 310 } 311 return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; 312 } 313 314 // Returns true if the given char is ':', '\0' or considered as blank space 315 // (i.e.: '\n', '\r', '\v', '\t', '\f' and ' '). 316 bool IsSeparator(char ch) { 317 return std::strchr(":\0", ch) || std::isspace(ch) != 0; 318 } 319 320 std::unique_ptr<SetSpecConstantDefaultValuePass::SpecIdToValueStrMap> 321 SetSpecConstantDefaultValuePass::ParseDefaultValuesString(const char* str) { 322 if (!str) return nullptr; 323 324 auto spec_id_to_value = MakeUnique<SpecIdToValueStrMap>(); 325 326 // The parsing loop, break when points to the end. 327 while (*str) { 328 // Find the spec id. 329 while (std::isspace(*str)) str++; // skip leading spaces. 330 const char* entry_begin = str; 331 while (!IsSeparator(*str)) str++; 332 const char* entry_end = str; 333 std::string spec_id_str(entry_begin, entry_end - entry_begin); 334 uint32_t spec_id = 0; 335 if (!ParseNumber(spec_id_str.c_str(), &spec_id)) { 336 // The spec id is not a valid uint32 number. 337 return nullptr; 338 } 339 auto iter = spec_id_to_value->find(spec_id); 340 if (iter != spec_id_to_value->end()) { 341 // Same spec id has been defined before 342 return nullptr; 343 } 344 // Find the ':', spaces between the spec id and the ':' are not allowed. 345 if (*str++ != ':') { 346 // ':' not found 347 return nullptr; 348 } 349 // Find the value string 350 const char* val_begin = str; 351 while (!IsSeparator(*str)) str++; 352 const char* val_end = str; 353 if (val_end == val_begin) { 354 // Value string is empty. 355 return nullptr; 356 } 357 // Update the mapping with spec id and value string. 358 (*spec_id_to_value)[spec_id] = std::string(val_begin, val_end - val_begin); 359 360 // Skip trailing spaces. 361 while (std::isspace(*str)) str++; 362 } 363 364 return spec_id_to_value; 365 } 366 367 } // namespace opt 368 } // namespace spvtools 369