/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

namespace tensorflow {
namespace graph_transforms {

// A single min/max observation recovered from one line of the log file.
// Kept as a plain aggregate: it is brace-initialized as {name, min, max}, so
// the order of the fields matters.
struct MinMaxRecord {
  // Text that preceded the "__print__" suffix in the log line; later matched
  // against node names in the graph.
  string name;
  // Value parsed from the first bracketed number of the log entry.
  float min;
  // Value parsed from the second bracketed number of the log entry.
  float max;
};

// Try to parse a log file containing loosely-structured lines, some of which
// are the min/max logs we want.
32 Status ExtractMinMaxRecords(const string& log_file_name, 33 std::vector<MinMaxRecord>* records) { 34 string file_data; 35 TF_RETURN_IF_ERROR( 36 ReadFileToString(Env::Default(), log_file_name, &file_data)); 37 const string print_suffix("__print__"); 38 const string requant_prefix("__requant_min_max:"); 39 std::vector<string> file_lines = str_util::Split(file_data, '\n'); 40 for (const string& file_line : file_lines) { 41 // We expect to find a line with components separated by semicolons, so to 42 // start make sure that the basic structure is in place/ 43 StringPiece line(file_line); 44 if (!line.contains(print_suffix + ";" + requant_prefix)) { 45 continue; 46 } 47 std::vector<string> line_parts = str_util::Split(file_line, ';'); 48 if (line_parts.size() < 2) { 49 continue; 50 } 51 // Now we want to figure out which components have the name and min max 52 // values by scanning for the prefix we expect. 53 bool min_max_found = false; 54 int min_max_index; 55 for (int i = 1; i < line_parts.size(); ++i) { 56 StringPiece line_part(line_parts[i]); 57 if (line_part.starts_with(requant_prefix)) { 58 min_max_found = true; 59 min_max_index = i; 60 } 61 } 62 if (!min_max_found) { 63 continue; 64 } 65 // Finally we need to break out the values from the strings, and parse them 66 // into a form we can use. 
67 string min_max_string = line_parts[min_max_index]; 68 std::vector<string> min_max_parts = str_util::Split(min_max_string, '['); 69 if ((min_max_parts.size() != 3) || (min_max_parts[0] != requant_prefix)) { 70 continue; 71 } 72 string min_string = min_max_parts[1]; 73 std::vector<string> min_string_parts = str_util::Split(min_string, ']'); 74 if (min_string_parts.size() != 2) { 75 continue; 76 } 77 string min_number_string = min_string_parts[0]; 78 float min; 79 if (!strings::safe_strtof(min_number_string.c_str(), &min)) { 80 continue; 81 } 82 string max_string = min_max_parts[2]; 83 std::vector<string> max_string_parts = str_util::Split(max_string, ']'); 84 if (max_string_parts.size() != 2) { 85 continue; 86 } 87 string max_number_string = max_string_parts[0]; 88 float max; 89 if (!strings::safe_strtof(max_number_string.c_str(), &max)) { 90 continue; 91 } 92 StringPiece name_string = line_parts[min_max_index - 1]; 93 if (!name_string.ends_with(print_suffix)) { 94 continue; 95 } 96 string name = 97 name_string.substr(0, name_string.size() - print_suffix.size()) 98 .ToString(); 99 records->push_back({name, min, max}); 100 } 101 return Status::OK(); 102 } 103 104 // Uses the observed min/max values for requantization captured in a log file to 105 // replace costly RequantizationRange ops with simple Consts. 
106 Status FreezeRequantizationRanges(const GraphDef& input_graph_def, 107 const TransformFuncContext& context, 108 GraphDef* output_graph_def) { 109 string min_max_log_file; 110 TF_RETURN_IF_ERROR( 111 context.GetOneStringParameter("min_max_log_file", "", &min_max_log_file)); 112 if (min_max_log_file.empty()) { 113 return errors::InvalidArgument( 114 "You must pass a file name to min_max_log_file"); 115 } 116 float min_percentile; 117 TF_RETURN_IF_ERROR( 118 context.GetOneFloatParameter("min_percentile", 5.0f, &min_percentile)); 119 float max_percentile; 120 TF_RETURN_IF_ERROR( 121 context.GetOneFloatParameter("max_percentile", 5.0f, &max_percentile)); 122 123 std::vector<MinMaxRecord> records; 124 TF_RETURN_IF_ERROR(ExtractMinMaxRecords(min_max_log_file, &records)); 125 if (records.empty()) { 126 return errors::InvalidArgument( 127 "No min/max range logs were found in the log file"); 128 } 129 130 std::map<string, const NodeDef*> node_map; 131 MapNamesToNodes(input_graph_def, &node_map); 132 bool any_missing_nodes = false; 133 std::map<string, std::vector<MinMaxRecord>> records_by_node; 134 for (const MinMaxRecord& record : records) { 135 records_by_node[record.name].push_back(record); 136 if (!node_map.count(record.name)) { 137 any_missing_nodes = true; 138 LOG(WARNING) << "Node from log not found in graph: " << record.name; 139 } 140 } 141 if (any_missing_nodes) { 142 return errors::InvalidArgument( 143 "Nodes were found in the log file that aren't present in the graph"); 144 } 145 146 // Now find out the largest and smallest min/max values for the node. 
147 std::map<string, std::pair<float, float>> range_for_nodes; 148 for (const auto& record_info : records_by_node) { 149 const string& name = record_info.first; 150 const std::vector<MinMaxRecord> records = record_info.second; 151 std::vector<float> mins; 152 std::vector<float> maxs; 153 for (const MinMaxRecord& record : records) { 154 mins.push_back(record.min); 155 maxs.push_back(record.max); 156 } 157 std::sort(mins.begin(), mins.end()); 158 std::sort(maxs.begin(), maxs.end()); 159 int min_index = std::round(mins.size() * (min_percentile / 100.0f)); 160 if (min_index < 0) { 161 min_index = 0; 162 } 163 int max_index = 164 std::round(maxs.size() * (1.0f - (max_percentile / 100.0f))); 165 if (max_index > (maxs.size() - 1)) { 166 max_index = maxs.size() - 1; 167 } 168 const float min = mins[min_index]; 169 const float max = maxs[max_index]; 170 range_for_nodes[name] = {min, max}; 171 } 172 std::map<string, string> inputs_to_rename; 173 GraphDef frozen_graph_def; 174 for (const NodeDef& node : input_graph_def.node()) { 175 if (range_for_nodes.count(node.name())) { 176 if (node.op() != "RequantizationRange") { 177 return errors::InvalidArgument( 178 "Node is expected to be a RequantizationRange op: ", node.name(), 179 ", but is: ", node.op()); 180 } 181 const float min_value = range_for_nodes.at(node.name()).first; 182 NodeDef* min_node = frozen_graph_def.mutable_node()->Add(); 183 min_node->set_op("Const"); 184 min_node->set_name(node.name() + "/frozen_min"); 185 SetNodeAttr("dtype", DT_FLOAT, min_node); 186 Tensor min_tensor(DT_FLOAT, {}); 187 min_tensor.flat<float>()(0) = min_value; 188 SetNodeTensorAttr<float>("value", min_tensor, min_node); 189 inputs_to_rename[node.name() + ":0"] = min_node->name() + ":0"; 190 191 const float max_value = range_for_nodes.at(node.name()).second; 192 NodeDef* max_node = frozen_graph_def.mutable_node()->Add(); 193 max_node->set_op("Const"); 194 max_node->set_name(node.name() + "/frozen_max"); 195 SetNodeAttr("dtype", DT_FLOAT, 
max_node); 196 Tensor max_tensor(DT_FLOAT, {}); 197 max_tensor.flat<float>()(0) = max_value; 198 SetNodeTensorAttr<float>("value", max_tensor, max_node); 199 inputs_to_rename[node.name() + ":1"] = max_node->name() + ":0"; 200 } else { 201 NodeDef* new_node = frozen_graph_def.mutable_node()->Add(); 202 *new_node = node; 203 } 204 } 205 return RenameNodeInputs(frozen_graph_def, inputs_to_rename, 206 std::unordered_set<string>(), output_graph_def); 207 } 208 209 REGISTER_GRAPH_TRANSFORM("freeze_requantization_ranges", 210 FreezeRequantizationRanges); 211 212 } // namespace graph_transforms 213 } // namespace tensorflow 214