1 // Copyright (c) 2017 Google Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include "tools/stats/spirv_stats.h" 16 17 #include <cassert> 18 19 #include <algorithm> 20 #include <memory> 21 #include <string> 22 23 #include "source/diagnostic.h" 24 #include "source/enum_string_mapping.h" 25 #include "source/extensions.h" 26 #include "source/id_descriptor.h" 27 #include "source/instruction.h" 28 #include "source/opcode.h" 29 #include "source/operand.h" 30 #include "source/val/instruction.h" 31 #include "source/val/validate.h" 32 #include "source/val/validation_state.h" 33 #include "spirv-tools/libspirv.h" 34 35 namespace spvtools { 36 namespace stats { 37 namespace { 38 39 // Helper class for stats aggregation. Receives as in/out parameter. 40 // Constructs ValidationState and updates it by running validator for each 41 // instruction. 42 class StatsAggregator { 43 public: 44 StatsAggregator(SpirvStats* in_out_stats, const val::ValidationState_t* state) 45 : stats_(in_out_stats), vstate_(state) {} 46 47 // Processes the instructions to collect stats. 48 void aggregate() { 49 const auto& instructions = vstate_->ordered_instructions(); 50 51 ++stats_->version_hist[vstate_->version()]; 52 ++stats_->generator_hist[vstate_->generator()]; 53 54 for (size_t i = 0; i < instructions.size(); ++i) { 55 const auto& inst = instructions[i]; 56 57 ProcessOpcode(&inst, i); 58 ProcessCapability(&inst); 59 ProcessExtension(&inst); 60 ProcessConstant(&inst); 61 } 62 } 63 64 // Collects OpCapability statistics. 65 void ProcessCapability(const val::Instruction* inst) { 66 if (inst->opcode() != SpvOpCapability) return; 67 const uint32_t capability = inst->word(inst->operands()[0].offset); 68 ++stats_->capability_hist[capability]; 69 } 70 71 // Collects OpExtension statistics. 72 void ProcessExtension(const val::Instruction* inst) { 73 if (inst->opcode() != SpvOpExtension) return; 74 const std::string extension = GetExtensionString(&inst->c_inst()); 75 ++stats_->extension_hist[extension]; 76 } 77 78 // Collects OpCode statistics. 79 void ProcessOpcode(const val::Instruction* inst, size_t idx) { 80 const SpvOp opcode = inst->opcode(); 81 ++stats_->opcode_hist[opcode]; 82 83 if (idx == 0) return; 84 85 --idx; 86 87 const auto& instructions = vstate_->ordered_instructions(); 88 89 auto step_it = stats_->opcode_markov_hist.begin(); 90 for (; step_it != stats_->opcode_markov_hist.end(); --idx, ++step_it) { 91 auto& hist = (*step_it)[instructions[idx].opcode()]; 92 ++hist[opcode]; 93 94 if (idx == 0) break; 95 } 96 } 97 98 // Collects OpConstant statistics. 99 void ProcessConstant(const val::Instruction* inst) { 100 if (inst->opcode() != SpvOpConstant) return; 101 102 const uint32_t type_id = inst->GetOperandAs<uint32_t>(0); 103 const auto type_decl_it = vstate_->all_definitions().find(type_id); 104 assert(type_decl_it != vstate_->all_definitions().end()); 105 106 const val::Instruction& type_decl_inst = *type_decl_it->second; 107 const SpvOp type_op = type_decl_inst.opcode(); 108 if (type_op == SpvOpTypeInt) { 109 const uint32_t bit_width = type_decl_inst.GetOperandAs<uint32_t>(1); 110 const uint32_t is_signed = type_decl_inst.GetOperandAs<uint32_t>(2); 111 assert(is_signed == 0 || is_signed == 1); 112 if (bit_width == 16) { 113 if (is_signed) 114 ++stats_->s16_constant_hist[inst->GetOperandAs<int16_t>(2)]; 115 else 116 ++stats_->u16_constant_hist[inst->GetOperandAs<uint16_t>(2)]; 117 } else if (bit_width == 32) { 118 if (is_signed) 119 ++stats_->s32_constant_hist[inst->GetOperandAs<int32_t>(2)]; 120 else 121 ++stats_->u32_constant_hist[inst->GetOperandAs<uint32_t>(2)]; 122 } else if (bit_width == 64) { 123 if (is_signed) 124 ++stats_->s64_constant_hist[inst->GetOperandAs<int64_t>(2)]; 125 else 126 ++stats_->u64_constant_hist[inst->GetOperandAs<uint64_t>(2)]; 127 } else { 128 assert(false && "TypeInt bit width is not 16, 32 or 64"); 129 } 130 } else if (type_op == SpvOpTypeFloat) { 131 const uint32_t bit_width = type_decl_inst.GetOperandAs<uint32_t>(1); 132 if (bit_width == 32) { 133 ++stats_->f32_constant_hist[inst->GetOperandAs<float>(2)]; 134 } else if (bit_width == 64) { 135 ++stats_->f64_constant_hist[inst->GetOperandAs<double>(2)]; 136 } else { 137 assert(bit_width == 16); 138 } 139 } 140 } 141 142 private: 143 SpirvStats* stats_; 144 const val::ValidationState_t* vstate_; 145 IdDescriptorCollection id_descriptors_; 146 }; 147 148 } // namespace 149 150 spv_result_t AggregateStats(const spv_context context, const uint32_t* words, 151 const size_t num_words, spv_diagnostic* pDiagnostic, 152 SpirvStats* stats) { 153 std::unique_ptr<val::ValidationState_t> vstate; 154 spv_validator_options_t options; 155 spv_result_t result = ValidateBinaryAndKeepValidationState( 156 context, &options, words, num_words, pDiagnostic, &vstate); 157 if (result != SPV_SUCCESS) return result; 158 159 StatsAggregator stats_aggregator(stats, vstate.get()); 160 stats_aggregator.aggregate(); 161 return SPV_SUCCESS; 162 } 163 164 } // namespace stats 165 } // namespace spvtools 166