Home | History | Annotate | Download | only in graph_transforms
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/node_def.pb.h"
     17 #include "tensorflow/core/lib/strings/str_util.h"
     18 #include "tensorflow/core/platform/env.h"
     19 #include "tensorflow/tools/graph_transforms/transform_utils.h"
     20 
     21 namespace tensorflow {
     22 namespace graph_transforms {
     23 
// One min/max observation recovered from a single log line.
// Instances are built positionally as {name, min, max}, so the member order
// below must not change.
struct MinMaxRecord {
  // Node name taken from the log entry, with the "__print__" suffix removed.
  string name;
  // First bracketed value logged for the node.
  float min;
  // Second bracketed value logged for the node.
  float max;
};
     29 
     30 // Try to parse a log file containing loosely-structured lines, some of which
     31 // are the min/max logs we want.
     32 Status ExtractMinMaxRecords(const string& log_file_name,
     33                             std::vector<MinMaxRecord>* records) {
     34   string file_data;
     35   TF_RETURN_IF_ERROR(
     36       ReadFileToString(Env::Default(), log_file_name, &file_data));
     37   const string print_suffix("__print__");
     38   const string requant_prefix("__requant_min_max:");
     39   std::vector<string> file_lines = str_util::Split(file_data, '\n');
     40   for (const string& file_line : file_lines) {
     41     // We expect to find a line with components separated by semicolons, so to
     42     // start make sure that the basic structure is in place/
     43     StringPiece line(file_line);
     44     if (!line.contains(print_suffix + ";" + requant_prefix)) {
     45       continue;
     46     }
     47     std::vector<string> line_parts = str_util::Split(file_line, ';');
     48     if (line_parts.size() < 2) {
     49       continue;
     50     }
     51     // Now we want to figure out which components have the name and min max
     52     // values by scanning for the prefix we expect.
     53     bool min_max_found = false;
     54     int min_max_index;
     55     for (int i = 1; i < line_parts.size(); ++i) {
     56       StringPiece line_part(line_parts[i]);
     57       if (line_part.starts_with(requant_prefix)) {
     58         min_max_found = true;
     59         min_max_index = i;
     60       }
     61     }
     62     if (!min_max_found) {
     63       continue;
     64     }
     65     // Finally we need to break out the values from the strings, and parse them
     66     // into a form we can use.
     67     string min_max_string = line_parts[min_max_index];
     68     std::vector<string> min_max_parts = str_util::Split(min_max_string, '[');
     69     if ((min_max_parts.size() != 3) || (min_max_parts[0] != requant_prefix)) {
     70       continue;
     71     }
     72     string min_string = min_max_parts[1];
     73     std::vector<string> min_string_parts = str_util::Split(min_string, ']');
     74     if (min_string_parts.size() != 2) {
     75       continue;
     76     }
     77     string min_number_string = min_string_parts[0];
     78     float min;
     79     if (!strings::safe_strtof(min_number_string.c_str(), &min)) {
     80       continue;
     81     }
     82     string max_string = min_max_parts[2];
     83     std::vector<string> max_string_parts = str_util::Split(max_string, ']');
     84     if (max_string_parts.size() != 2) {
     85       continue;
     86     }
     87     string max_number_string = max_string_parts[0];
     88     float max;
     89     if (!strings::safe_strtof(max_number_string.c_str(), &max)) {
     90       continue;
     91     }
     92     StringPiece name_string = line_parts[min_max_index - 1];
     93     if (!name_string.ends_with(print_suffix)) {
     94       continue;
     95     }
     96     string name =
     97         name_string.substr(0, name_string.size() - print_suffix.size())
     98             .ToString();
     99     records->push_back({name, min, max});
    100   }
    101   return Status::OK();
    102 }
    103 
    104 // Uses the observed min/max values for requantization captured in a log file to
    105 // replace costly RequantizationRange ops with simple Consts.
    106 Status FreezeRequantizationRanges(const GraphDef& input_graph_def,
    107                                   const TransformFuncContext& context,
    108                                   GraphDef* output_graph_def) {
    109   string min_max_log_file;
    110   TF_RETURN_IF_ERROR(
    111       context.GetOneStringParameter("min_max_log_file", "", &min_max_log_file));
    112   if (min_max_log_file.empty()) {
    113     return errors::InvalidArgument(
    114         "You must pass a file name to min_max_log_file");
    115   }
    116   float min_percentile;
    117   TF_RETURN_IF_ERROR(
    118       context.GetOneFloatParameter("min_percentile", 5.0f, &min_percentile));
    119   float max_percentile;
    120   TF_RETURN_IF_ERROR(
    121       context.GetOneFloatParameter("max_percentile", 5.0f, &max_percentile));
    122 
    123   std::vector<MinMaxRecord> records;
    124   TF_RETURN_IF_ERROR(ExtractMinMaxRecords(min_max_log_file, &records));
    125   if (records.empty()) {
    126     return errors::InvalidArgument(
    127         "No min/max range logs were found in the log file");
    128   }
    129 
    130   std::map<string, const NodeDef*> node_map;
    131   MapNamesToNodes(input_graph_def, &node_map);
    132   bool any_missing_nodes = false;
    133   std::map<string, std::vector<MinMaxRecord>> records_by_node;
    134   for (const MinMaxRecord& record : records) {
    135     records_by_node[record.name].push_back(record);
    136     if (!node_map.count(record.name)) {
    137       any_missing_nodes = true;
    138       LOG(WARNING) << "Node from log not found in graph: " << record.name;
    139     }
    140   }
    141   if (any_missing_nodes) {
    142     return errors::InvalidArgument(
    143         "Nodes were found in the log file that aren't present in the graph");
    144   }
    145 
    146   // Now find out the largest and smallest min/max values for the node.
    147   std::map<string, std::pair<float, float>> range_for_nodes;
    148   for (const auto& record_info : records_by_node) {
    149     const string& name = record_info.first;
    150     const std::vector<MinMaxRecord> records = record_info.second;
    151     std::vector<float> mins;
    152     std::vector<float> maxs;
    153     for (const MinMaxRecord& record : records) {
    154       mins.push_back(record.min);
    155       maxs.push_back(record.max);
    156     }
    157     std::sort(mins.begin(), mins.end());
    158     std::sort(maxs.begin(), maxs.end());
    159     int min_index = std::round(mins.size() * (min_percentile / 100.0f));
    160     if (min_index < 0) {
    161       min_index = 0;
    162     }
    163     int max_index =
    164         std::round(maxs.size() * (1.0f - (max_percentile / 100.0f)));
    165     if (max_index > (maxs.size() - 1)) {
    166       max_index = maxs.size() - 1;
    167     }
    168     const float min = mins[min_index];
    169     const float max = maxs[max_index];
    170     range_for_nodes[name] = {min, max};
    171   }
    172   std::map<string, string> inputs_to_rename;
    173   GraphDef frozen_graph_def;
    174   for (const NodeDef& node : input_graph_def.node()) {
    175     if (range_for_nodes.count(node.name())) {
    176       if (node.op() != "RequantizationRange") {
    177         return errors::InvalidArgument(
    178             "Node is expected to be a RequantizationRange op: ", node.name(),
    179             ", but is: ", node.op());
    180       }
    181       const float min_value = range_for_nodes.at(node.name()).first;
    182       NodeDef* min_node = frozen_graph_def.mutable_node()->Add();
    183       min_node->set_op("Const");
    184       min_node->set_name(node.name() + "/frozen_min");
    185       SetNodeAttr("dtype", DT_FLOAT, min_node);
    186       Tensor min_tensor(DT_FLOAT, {});
    187       min_tensor.flat<float>()(0) = min_value;
    188       SetNodeTensorAttr<float>("value", min_tensor, min_node);
    189       inputs_to_rename[node.name() + ":0"] = min_node->name() + ":0";
    190 
    191       const float max_value = range_for_nodes.at(node.name()).second;
    192       NodeDef* max_node = frozen_graph_def.mutable_node()->Add();
    193       max_node->set_op("Const");
    194       max_node->set_name(node.name() + "/frozen_max");
    195       SetNodeAttr("dtype", DT_FLOAT, max_node);
    196       Tensor max_tensor(DT_FLOAT, {});
    197       max_tensor.flat<float>()(0) = max_value;
    198       SetNodeTensorAttr<float>("value", max_tensor, max_node);
    199       inputs_to_rename[node.name() + ":1"] = max_node->name() + ":0";
    200     } else {
    201       NodeDef* new_node = frozen_graph_def.mutable_node()->Add();
    202       *new_node = node;
    203     }
    204   }
    205   return RenameNodeInputs(frozen_graph_def, inputs_to_rename,
    206                           std::unordered_set<string>(), output_graph_def);
    207 }
    208 
// Associates the name "freeze_requantization_ranges" with
// FreezeRequantizationRanges.
// NOTE(review): the macro is defined outside this file; presumably it adds the
// function to a registry that is looked up by this name -- confirm against
// transform_utils.h.
REGISTER_GRAPH_TRANSFORM("freeze_requantization_ranges",
                         FreezeRequantizationRanges);
    211 
    212 }  // namespace graph_transforms
    213 }  // namespace tensorflow
    214