// Copyright (c) 2015-2016 The Khronos Group Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and/or associated documentation files (the
// "Materials"), to deal in the Materials without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Materials, and to
// permit persons to whom the Materials are furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be included
// in all copies or substantial portions of the Materials.
//
// MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
// KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
// SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
// https://www.khronos.org/registry/
//
// THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
26 27 #include "val/Function.h" 28 29 #include <cassert> 30 31 #include <algorithm> 32 #include <utility> 33 34 #include "val/BasicBlock.h" 35 #include "val/Construct.h" 36 #include "val/ValidationState.h" 37 38 using std::ignore; 39 using std::list; 40 using std::make_pair; 41 using std::pair; 42 using std::string; 43 using std::tie; 44 using std::vector; 45 46 namespace libspirv { 47 namespace { 48 49 void printDot(const BasicBlock& other, const ValidationState_t& module) { 50 string block_string; 51 if (other.successors()->empty()) { 52 block_string += "end "; 53 } else { 54 for (auto block : *other.successors()) { 55 block_string += module.getIdOrName(block->id()) + " "; 56 } 57 } 58 printf("%10s -> {%s\b}\n", module.getIdOrName(other.id()).c_str(), 59 block_string.c_str()); 60 } 61 } /// namespace 62 63 Function::Function(uint32_t function_id, uint32_t result_type_id, 64 SpvFunctionControlMask function_control, 65 uint32_t function_type_id, ValidationState_t& module) 66 : module_(module), 67 id_(function_id), 68 function_type_id_(function_type_id), 69 result_type_id_(result_type_id), 70 function_control_(function_control), 71 declaration_type_(FunctionDecl::kFunctionDeclUnknown), 72 end_has_been_registered_(false), 73 blocks_(), 74 current_block_(nullptr), 75 pseudo_entry_block_(0), 76 pseudo_exit_block_(kInvalidId), 77 pseudo_entry_blocks_({&pseudo_entry_block_}), 78 pseudo_exit_blocks_({&pseudo_exit_block_}), 79 cfg_constructs_(), 80 variable_ids_(), 81 parameter_ids_() {} 82 83 bool Function::IsFirstBlock(uint32_t block_id) const { 84 return !ordered_blocks_.empty() && *first_block() == block_id; 85 } 86 87 spv_result_t Function::RegisterFunctionParameter(uint32_t parameter_id, 88 uint32_t type_id) { 89 assert(module_.in_function_body() == true && 90 "RegisterFunctionParameter can only be called when parsing the binary " 91 "outside of another function"); 92 assert(current_block_ == nullptr && 93 "RegisterFunctionParameter can only be called when parsing 
the binary " 94 "ouside of a block"); 95 // TODO(umar): Validate function parameter type order and count 96 // TODO(umar): Use these variables to validate parameter type 97 (void)parameter_id; 98 (void)type_id; 99 return SPV_SUCCESS; 100 } 101 102 spv_result_t Function::RegisterLoopMerge(uint32_t merge_id, 103 uint32_t continue_id) { 104 RegisterBlock(merge_id, false); 105 RegisterBlock(continue_id, false); 106 BasicBlock& merge_block = blocks_.at(merge_id); 107 BasicBlock& continue_block = blocks_.at(continue_id); 108 assert(current_block_ && 109 "RegisterLoopMerge must be called when called within a block"); 110 111 current_block_->set_type(kBlockTypeLoop); 112 merge_block.set_type(kBlockTypeMerge); 113 continue_block.set_type(kBlockTypeContinue); 114 cfg_constructs_.emplace_back(ConstructType::kLoop, current_block_, 115 &merge_block); 116 Construct& loop_construct = cfg_constructs_.back(); 117 cfg_constructs_.emplace_back(ConstructType::kContinue, &continue_block); 118 Construct& continue_construct = cfg_constructs_.back(); 119 continue_construct.set_corresponding_constructs({&loop_construct}); 120 loop_construct.set_corresponding_constructs({&continue_construct}); 121 122 return SPV_SUCCESS; 123 } 124 125 spv_result_t Function::RegisterSelectionMerge(uint32_t merge_id) { 126 RegisterBlock(merge_id, false); 127 BasicBlock& merge_block = blocks_.at(merge_id); 128 current_block_->set_type(kBlockTypeHeader); 129 merge_block.set_type(kBlockTypeMerge); 130 131 cfg_constructs_.emplace_back(ConstructType::kSelection, current_block(), 132 &merge_block); 133 return SPV_SUCCESS; 134 } 135 136 void Function::PrintDotGraph() const { 137 if (first_block()) { 138 string func_name(module_.getIdOrName(id_)); 139 printf("digraph %s {\n", func_name.c_str()); 140 PrintBlocks(); 141 printf("}\n"); 142 } 143 } 144 145 void Function::PrintBlocks() const { 146 if (first_block()) { 147 printf("%10s -> %s\n", module_.getIdOrName(id_).c_str(), 148 
module_.getIdOrName(first_block()->id()).c_str()); 149 for (const auto& block : blocks_) { 150 printDot(block.second, module_); 151 } 152 } 153 } 154 155 spv_result_t Function::RegisterSetFunctionDeclType(FunctionDecl type) { 156 assert(declaration_type_ == FunctionDecl::kFunctionDeclUnknown); 157 declaration_type_ = type; 158 return SPV_SUCCESS; 159 } 160 161 spv_result_t Function::RegisterBlock(uint32_t block_id, bool is_definition) { 162 assert(module_.in_function_body() == true && 163 "RegisterBlocks can only be called when parsing a binary inside of a " 164 "function"); 165 assert(module_.current_layout_section() != 166 ModuleLayoutSection::kLayoutFunctionDeclarations && 167 "RegisterBlocks cannot be called within a function declaration"); 168 assert( 169 declaration_type_ == FunctionDecl::kFunctionDeclDefinition && 170 "RegisterBlocks can only be called after declaration_type_ is defined"); 171 172 std::unordered_map<uint32_t, BasicBlock>::iterator inserted_block; 173 bool success = false; 174 tie(inserted_block, success) = 175 blocks_.insert({block_id, BasicBlock(block_id)}); 176 if (is_definition) { // new block definition 177 assert(current_block_ == nullptr && 178 "Register Block can only be called when parsing a binary outside of " 179 "a BasicBlock"); 180 181 undefined_blocks_.erase(block_id); 182 current_block_ = &inserted_block->second; 183 ordered_blocks_.push_back(current_block_); 184 if (IsFirstBlock(block_id)) current_block_->set_reachable(true); 185 } else if (success) { // Block doesn't exsist but this is not a definition 186 undefined_blocks_.insert(block_id); 187 } 188 189 return SPV_SUCCESS; 190 } 191 192 void Function::RegisterBlockEnd(vector<uint32_t> next_list, 193 SpvOp branch_instruction) { 194 assert(module_.in_function_body() == true && 195 "RegisterBlockEnd can only be called when parsing a binary in a " 196 "function"); 197 assert( 198 current_block_ && 199 "RegisterBlockEnd can only be called when parsing a binary in a block"); 200 
201 vector<BasicBlock*> next_blocks; 202 next_blocks.reserve(next_list.size()); 203 204 std::unordered_map<uint32_t, BasicBlock>::iterator inserted_block; 205 bool success; 206 for (uint32_t successor_id : next_list) { 207 tie(inserted_block, success) = 208 blocks_.insert({successor_id, BasicBlock(successor_id)}); 209 if (success) { 210 undefined_blocks_.insert(successor_id); 211 } 212 next_blocks.push_back(&inserted_block->second); 213 } 214 215 current_block_->RegisterBranchInstruction(branch_instruction); 216 current_block_->RegisterSuccessors(next_blocks); 217 current_block_ = nullptr; 218 return; 219 } 220 221 void Function::RegisterFunctionEnd() { 222 if (!end_has_been_registered_) { 223 end_has_been_registered_ = true; 224 225 // Compute the successors of the pseudo-entry block, and 226 // the predecessors of the pseudo exit block. 227 vector<BasicBlock*> sources; 228 vector<BasicBlock*> sinks; 229 for (const auto b : ordered_blocks_) { 230 if (b->predecessors()->empty()) sources.push_back(b); 231 if (b->successors()->empty()) sinks.push_back(b); 232 } 233 pseudo_entry_block_.SetSuccessorsUnsafe(std::move(sources)); 234 pseudo_exit_block_.SetPredecessorsUnsafe(std::move(sinks)); 235 } 236 } 237 238 size_t Function::block_count() const { return blocks_.size(); } 239 240 size_t Function::undefined_block_count() const { 241 return undefined_blocks_.size(); 242 } 243 244 const vector<BasicBlock*>& Function::ordered_blocks() const { 245 return ordered_blocks_; 246 } 247 vector<BasicBlock*>& Function::ordered_blocks() { return ordered_blocks_; } 248 249 const BasicBlock* Function::current_block() const { return current_block_; } 250 BasicBlock* Function::current_block() { return current_block_; } 251 252 BasicBlock* Function::pseudo_entry_block() { return &pseudo_entry_block_; } 253 const BasicBlock* Function::pseudo_entry_block() const { 254 return &pseudo_entry_block_; 255 } 256 257 BasicBlock* Function::pseudo_exit_block() { return &pseudo_exit_block_; } 258 
const BasicBlock* Function::pseudo_exit_block() const { 259 return &pseudo_exit_block_; 260 } 261 262 const list<Construct>& Function::constructs() const { return cfg_constructs_; } 263 list<Construct>& Function::constructs() { return cfg_constructs_; } 264 265 const BasicBlock* Function::first_block() const { 266 if (ordered_blocks_.empty()) return nullptr; 267 return ordered_blocks_[0]; 268 } 269 BasicBlock* Function::first_block() { 270 if (ordered_blocks_.empty()) return nullptr; 271 return ordered_blocks_[0]; 272 } 273 274 bool Function::IsBlockType(uint32_t merge_block_id, BlockType type) const { 275 bool ret = false; 276 const BasicBlock* block; 277 tie(block, ignore) = GetBlock(merge_block_id); 278 if (block) { 279 ret = block->is_type(type); 280 } 281 return ret; 282 } 283 284 pair<const BasicBlock*, bool> Function::GetBlock(uint32_t block_id) const { 285 const auto b = blocks_.find(block_id); 286 if (b != end(blocks_)) { 287 const BasicBlock* block = &(b->second); 288 bool defined = 289 undefined_blocks_.find(block->id()) == end(undefined_blocks_); 290 return make_pair(block, defined); 291 } else { 292 return make_pair(nullptr, false); 293 } 294 } 295 296 pair<BasicBlock*, bool> Function::GetBlock(uint32_t block_id) { 297 const BasicBlock* out; 298 bool defined; 299 tie(out, defined) = const_cast<const Function*>(this)->GetBlock(block_id); 300 return make_pair(const_cast<BasicBlock*>(out), defined); 301 } 302 } /// namespace libspirv 303