Home | History | Annotate | Download | only in graph_transforms
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      7     http://www.apache.org/licenses/LICENSE-2.0
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     16 #define EIGEN_USE_THREADS
     18 #include "tensorflow/core/common_runtime/constant_folding.h"
     19 #include "tensorflow/core/common_runtime/threadpool_device.h"
     20 #include "tensorflow/core/graph/graph_constructor.h"
     21 #include "tensorflow/core/graph/node_builder.h"
     22 #include "tensorflow/core/graph/subgraph.h"
     23 #include "tensorflow/core/kernels/quantization_utils.h"
     24 #include "tensorflow/core/platform/init_main.h"
     25 #include "tensorflow/core/public/session.h"
     26 #include "tensorflow/tools/graph_transforms/transform_utils.h"
     28 namespace tensorflow {
     29 namespace graph_transforms {
     31 // Holds the information we need to translate from a float version of this op
     32 // into the quantized equivalent.
     33 struct QuantizedOpInfo {
     34   // The name of the float op.
     35   string float_name;
     36   // Which attributes to copy directly over.
     37   std::vector<string> attrs_to_copy;
     38   // Extra data type attributes we need to set.
     39   std::vector<std::pair<string, DataType>> dtypes_to_set;
     40   // What depth of inputs the op can read in.
     41   DataType input_bit_depth;
     42   // The depth of the op's quantized outputs.
     43   DataType output_bit_depth;
     44   // Which inputs (e.g. shapes) aren't involved in the quantization process.
     45   std::set<int32> unquantized_inputs;
     46   // How the outputs are arranged, either
     47   // [input0, input1, min0, max0, min1, max1] for contiguous, or
     48   // [input0, input1, min0, min1, max0, max1] for separate.
     49   // The separate order is needed because it's the only way to specify unknown
     50   // numbers of inputs for ops like Concat.
     51   enum { CONTIGUOUS_MIN_MAX, SEPARATE_MIN_MAX } min_max_order;
     52 };
     54 // Every op that has a quantized equivalent should be listed here, so that the
     55 // conversion process can transform them.
     56 const std::vector<QuantizedOpInfo>& GetQuantizedOpList() {
     57   static const std::vector<QuantizedOpInfo> op_list = {
     58       {"Add",
     59        {},
     60        {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
     61        DT_QUINT8,
     62        DT_QINT32,
     63        {},
     64        QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
     65       {"AvgPool",
     66        {"ksize", "strides", "padding"},
     67        {{"T", DT_QUINT8}},
     68        DT_QUINT8,
     69        DT_QUINT8,
     70        {},
     71        QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
     72       {"BiasAdd",
     73        {},
     74        {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"out_type", DT_QINT32}},
     75        DT_QUINT8,
     76        DT_QINT32,
     77        {},
     78        QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
     79       {"Concat",
     80        {"N"},
     81        {{"T", DT_QUINT8}},
     82        DT_QUINT8,
     83        DT_QUINT8,
     84        {0},
     85        QuantizedOpInfo::SEPARATE_MIN_MAX},
     86       {"Conv2D",
     87        {"strides", "padding"},
     88        {{"Tinput", DT_QUINT8}, {"Tfilter", DT_QUINT8}, {"out_type", DT_QINT32}},
     89        DT_QUINT8,
     90        DT_QINT32,
     91        {},
     92        QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
     93       {"MatMul",
     94        {"transpose_a", "transpose_b"},
     95        {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
     96        DT_QUINT8,
     97        DT_QINT32,
     98        {},
     99        QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
    100       {"MaxPool",
    101        {"ksize", "strides", "padding"},
    102        {{"T", DT_QUINT8}},
    103        DT_QUINT8,
    104        DT_QUINT8,
    105        {},
    106        QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
    107       {"Mul",
    108        {},
    109        {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
    110        DT_QUINT8,
    111        DT_QINT32,
    112        {},
    113        QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
    114       {"Relu",
    115        {},
    116        {{"Tinput", DT_QUINT8}},
    117        DT_QUINT8,
    118        DT_QUINT8,
    119        {},
    120        QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
    121       {"ResizeBilinear",
    122        {"align_corners"},
    123        {{"T", DT_QUINT8}},
    124        DT_QUINT8,
    125        DT_QUINT8,
    126        {1},
    127        QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
    128       {"Relu6",
    129        {},
    130        {{"Tinput", DT_QUINT8}},
    131        DT_QUINT8,
    132        DT_QUINT8,
    133        {},
    134        QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
    135       {"Reshape",
    136        {},
    137        {{"T", DT_QUINT8}},
    138        DT_QUINT8,
    139        DT_QUINT8,
    140        {1},
    141        QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
    142   };
    143   return op_list;
    144 }
    146 namespace {
    147 // Replaces invalid characters in input names to get a unique node name.
    148 string UniqueNodeNameFromInput(const string& input_name) {
    149   string prefix;
    150   string node_name;
    151   string suffix;
    152   NodeNamePartsFromInput(input_name, &prefix, &node_name, &suffix);
    153   string result;
    154   if (prefix == "^") {
    155     result += "__hat__";
    156   }
    157   result += node_name;
    158   if (!suffix.empty()) {
    159     result += "__port__" + suffix.substr(1, suffix.size() - 1);
    160   }
    161   return result;
    162 }
    164 // Pulls two float values from the named parameters, with a lot of checking.
    165 Status ExtractRangeFromParams(const TransformFuncContext& context,
    166                               const string& min_name, const string& max_name,
    167                               float* min_value, float* max_value,
    168                               bool* has_range) {
    169   // See if we've been given quantized inputs with a known range.
    170   const bool has_min = (context.params.count(min_name) != 0);
    171   const bool has_max = (context.params.count(max_name) != 0);
    172   *has_range = (has_min || has_max);
    173   if (!*has_range) {
    174     return Status::OK();
    175   }
    176   if (!has_min || !has_max) {
    177     return errors::InvalidArgument("You must pass both ", min_name, " and ",
    178                                    max_name, " into quantize_nodes");
    179   }
    180   TF_RETURN_IF_ERROR(context.GetOneFloatParameter(min_name, 0.0f, min_value));
    181   TF_RETURN_IF_ERROR(context.GetOneFloatParameter(max_name, 0.0f, max_value));
    182   return Status::OK();
    183 }
    185 }  // namespace
    187 // Analyzes all the nodes in the graph to figure out which ones are duplicates
    188 // apart from their names. This commonly includes identical Const nodes, but can
    189 // also be simple operations that are repeated on multiple outputs of a
    190 // particular node. The complexity is managed using a hash function that avoids
    191 // the need for any O(n^2) algorithms when identifying duplicates.
    192 Status MergeDuplicateNodes(const GraphDef& input_graph_def,
    193                            const TransformFuncContext& context,
    194                            GraphDef* output_graph_def) {
    195   // Make sure we can look up inputs and outputs quickly.
    196   std::set<string> input_names(context.input_names.begin(),
    197                                context.input_names.end());
    198   std::set<string> output_names(context.output_names.begin(),
    199                                 context.output_names.end());
    200   GraphDef current_graph_def = input_graph_def;
    201   // Keep running the merging until no more duplicates are found.
    202   bool any_duplicates_found;
    203   do {
    204     any_duplicates_found = false;
    205     // First arrange all of the nodes by a hash of their contents.
    206     std::map<uint64, std::vector<const NodeDef*>> hashed_nodes;
    207     for (const NodeDef& node : current_graph_def.node()) {
    208       NodeDef nameless_node = node;
    209       // The name matters if it's being used as an input or output node,
    210       // otherwise ignore it when looking for duplicates.
    211       if (!input_names.count(node.name()) && !output_names.count(node.name())) {
    212         nameless_node.set_name("");
    213       }
    214       const uint64 hash = HashNodeDef(nameless_node);
    215       hashed_nodes[hash].push_back(&node);
    216     }
    217     // If we have multiple nodes with the same hash, then we know they're
    218     // duplicates and can be removed, unless they're stateful.
    219     std::map<string, string> inputs_to_rename;
    220     GraphDef merged_graph_def;
    221     for (const std::pair<uint64, std::vector<const NodeDef*>> hashed_node_info :
    222          hashed_nodes) {
    223       const std::vector<const NodeDef*>& hash_node_list =
    224           hashed_node_info.second;
    225       for (int i = 0; i < hash_node_list.size(); ++i) {
    226         const NodeDef* current_node = hash_node_list[i];
    227         const OpDef* op_def = nullptr;
    228         TF_RETURN_IF_ERROR(
    229             OpRegistry::Global()->LookUpOpDef(current_node->op(), &op_def));
    230         const bool is_duplicate = ((!op_def->is_stateful()) && (i > 0));
    231         if (is_duplicate) {
    232           const string original_name = hash_node_list[0]->name();
    233           inputs_to_rename[current_node->name() + ":*"] = original_name;
    234           any_duplicates_found = true;
    235         } else {
    236           NodeDef* new_node = merged_graph_def.mutable_node()->Add();
    237           *new_node = *current_node;
    238         }
    239       }
    240     }
    241     // Update the graph so that any nodes that referred to removed inputs now
    242     // pull from the remaining duplicate.
    243     TF_RETURN_IF_ERROR(RenameNodeInputs(merged_graph_def, inputs_to_rename,
    244                                         std::unordered_set<string>(),
    245                                         &current_graph_def));
    246   } while (any_duplicates_found);
    248   *output_graph_def = current_graph_def;
    250   return Status::OK();
    251 }
    253 // Looks for the patterns that indicate there are two eight-bit ops feeding into
    254 // each other, separated by a conversion up to float and back again. These occur
    255 // during the initial conversion of ops to their quantized forms. Because we're
    256 // only looking at an individual op in that phase and don't know if its inputs
    257 // and outputs are eight-bit-capable, we start by converting the actual op into
    258 // quantized form, but add float conversions before and after. This pass gets
    259 // rid of those conversions if it turns out we do have adjacent ops capable of
    260 // eight-bit processing.
    261 Status RemoveRedundantQuantizations(const GraphDef& input_graph_def,
    262                                     const TransformFuncContext& context,
    263                                     GraphDef* output_graph_def) {
    264   std::set<string> graph_outputs;
    265   for (const string& output_name : context.output_names) {
    266     graph_outputs.insert(NodeNameFromInput(output_name));
    267   }
    268   std::map<string, string> inputs_to_rename;
    269   GraphDef replaced_graph_def;
    270   TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
    271       input_graph_def,  // clang-format off
    272       {"QuantizeV2",
    273         {
    274           {"Dequantize"},
    275           {"Min"},
    276           {"Max"},
    277         }
    278       },  // clang-format on
    279       [&inputs_to_rename, &graph_outputs](const NodeMatch& match,
    280                                           const std::set<string>& input_nodes,
    281                                           const std::set<string>& output_nodes,
    282                                           std::vector<NodeDef>* new_nodes) {
    283         const NodeDef& quantize_node = match.node;
    284         const NodeDef& dequantize_node = match.inputs[0].node;
    285         inputs_to_rename[quantize_node.name() + ":0"] =
    286             dequantize_node.input(0);
    287         inputs_to_rename[quantize_node.name() + ":1"] =
    288             dequantize_node.input(1);
    289         inputs_to_rename[quantize_node.name() + ":2"] =
    290             dequantize_node.input(2);
    292         // Are other sub-graphs using the float intermediate result? If so,
    293         // preserve it, but the input renaming still rewires the eight-bit ops
    294         // so they don't go through float.
    295         if (output_nodes.count(dequantize_node.name()) ||
    296             graph_outputs.count(dequantize_node.name())) {
    297           CopyOriginalMatch(match, new_nodes);
    298         }
    300         return Status::OK();
    301       },
    302       {true}, &replaced_graph_def));
    304   return RenameNodeInputs(replaced_graph_def, inputs_to_rename,
    305                           std::unordered_set<string>(), output_graph_def);
    306 }
    308 // If the user has passed in the input_min and input_max args, then we need to
    309 // convert any input placeholders from float to eight bit, so quantized inputs
    310 // can be fed directly into the graph.
    311 Status QuantizePlaceholders(const GraphDef& input_graph_def,
    312                             const TransformFuncContext& context,
    313                             GraphDef* output_graph_def) {
    314   float input_min;
    315   float input_max;
    316   bool has_input_range;
    317   TF_RETURN_IF_ERROR(ExtractRangeFromParams(context, "input_min", "input_max",
    318                                             &input_min, &input_max,
    319                                             &has_input_range));
    320   if (!has_input_range) {
    321     *output_graph_def = input_graph_def;
    322     return Status::OK();
    323   }
    324   std::map<string, string> inputs_to_rename_first_pass;
    325   std::map<string, string> inputs_to_rename_second_pass;
    326   GraphDef placeholder_graph_def;
    327   placeholder_graph_def.Clear();
    328   for (const NodeDef& node : input_graph_def.node()) {
    329     if (node.op() != "Placeholder") {
    330       *(placeholder_graph_def.mutable_node()->Add()) = node;
    331     } else {
    332       string namespace_prefix = node.name() + "_eightbit";
    334       NodeDef quantized_placeholder;
    335       quantized_placeholder = node;
    336       SetNodeAttr("dtype", DT_QUINT8, &quantized_placeholder);
    337       *(placeholder_graph_def.mutable_node()->Add()) = quantized_placeholder;
    339       NodeDef min_node;
    340       min_node.set_op("Const");
    341       min_node.set_name(namespace_prefix + "/min");
    342       SetNodeAttr("dtype", DT_FLOAT, &min_node);
    343       Tensor min_tensor(DT_FLOAT, {});
    344       min_tensor.flat<float>()(0) = input_min;
    345       SetNodeTensorAttr<float>("value", min_tensor, &min_node);
    346       *(placeholder_graph_def.mutable_node()->Add()) = min_node;
    348       NodeDef max_node;
    349       max_node.set_op("Const");
    350       max_node.set_name(namespace_prefix + "/max");
    351       SetNodeAttr("dtype", DT_FLOAT, &max_node);
    352       Tensor max_tensor(DT_FLOAT, {});
    353       max_tensor.flat<float>()(0) = input_max;
    354       SetNodeTensorAttr<float>("value", max_tensor, &max_node);
    355       *(placeholder_graph_def.mutable_node()->Add()) = max_node;
    357       const string rename_suffix = "__RENAMED_PLACEHOLDER__";
    358       NodeDef dequantize_node;
    359       dequantize_node.set_op("Dequantize");
    360       dequantize_node.set_name(namespace_prefix + "/dequantize");
    361       SetNodeAttr("T", DT_QUINT8, &dequantize_node);
    362       SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
    363       AddNodeInput(node.name() + rename_suffix, &dequantize_node);
    364       AddNodeInput(min_node.name(), &dequantize_node);
    365       AddNodeInput(max_node.name(), &dequantize_node);
    366       *(placeholder_graph_def.mutable_node()->Add()) = dequantize_node;
    368       // First make sure that any internal references to the old placeholder
    369       // now point to the dequantize result.
    370       inputs_to_rename_first_pass[node.name()] = dequantize_node.name();
    371       // Then fix up the dequantize op so that it really points to the
    372       // placeholder.
    373       inputs_to_rename_second_pass[node.name() + rename_suffix] = node.name();
    374     }
    375   }
    377   GraphDef first_pass_graph_def;
    379       RenameNodeInputs(placeholder_graph_def, inputs_to_rename_first_pass,
    380                        std::unordered_set<string>(), &first_pass_graph_def));
    382       RenameNodeInputs(first_pass_graph_def, inputs_to_rename_second_pass,
    383                        std::unordered_set<string>(), output_graph_def));
    385   return Status::OK();
    386 }
    388 // During training, FakeQuantWithMinMaxVars ops capture a good min/max range for
    389 // an activation layer. To use these during inference, this pass converts those
    390 // ops into Requantizes with the trained min/maxes as constant inputs.
    391 Status ConvertFakeQuantsToRequantize(const GraphDef& input_graph_def,
    392                                      const TransformFuncContext& context,
    393                                      GraphDef* output_graph_def) {
    394   TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
    395       input_graph_def,  // clang-format off
    396       {"FakeQuantWithMinMaxVars",
    397         {
    398           {"*"},
    399           {"Const"},
    400           {"Const"},
    401         }
    402       },  // clang-format on
    403       [](const NodeMatch& match, const std::set<string>& input_nodes,
    404          const std::set<string>& output_nodes,
    405          std::vector<NodeDef>* new_nodes) {
    406         const NodeDef& fake_quant_node = match.node;
    407         const NodeDef& original_op_node = match.inputs[0].node;
    408         const NodeDef& fake_quant_min_node = match.inputs[1].node;
    409         const NodeDef& fake_quant_max_node = match.inputs[2].node;
    411         string namespace_prefix = fake_quant_node.name() + "_eightbit";
    413         new_nodes->push_back(original_op_node);
    414         new_nodes->push_back(fake_quant_min_node);
    415         new_nodes->push_back(fake_quant_max_node);
    417         NodeDef quantize_node;
    418         quantize_node.set_op("QuantizeV2");
    419         quantize_node.set_name(namespace_prefix + "/quantize");
    420         SetNodeAttr("T", DT_QINT32, &quantize_node);
    421         SetNodeAttr("mode", "MIN_FIRST", &quantize_node);
    422         AddNodeInput(fake_quant_node.input(0), &quantize_node);
    423         AddNodeInput(fake_quant_min_node.name(), &quantize_node);
    424         AddNodeInput(fake_quant_max_node.name(), &quantize_node);
    425         new_nodes->push_back(quantize_node);
    427         NodeDef requantize_node;
    428         requantize_node.set_op("Requantize");
    429         requantize_node.set_name(namespace_prefix + "/requantize");
    430         SetNodeAttr("Tinput", DT_QINT32, &requantize_node);
    431         SetNodeAttr("out_type", DT_QUINT8, &requantize_node);
    432         AddNodeInput(quantize_node.name() + ":0", &requantize_node);
    433         AddNodeInput(quantize_node.name() + ":1", &requantize_node);
    434         AddNodeInput(quantize_node.name() + ":2", &requantize_node);
    435         AddNodeInput(fake_quant_min_node.name(), &requantize_node);
    436         AddNodeInput(fake_quant_max_node.name(), &requantize_node);
    437         new_nodes->push_back(requantize_node);
    439         // Convert the 8-bit result back into float for the final output.
    440         NodeDef dequantize_node;
    441         dequantize_node.set_op("Dequantize");
    442         dequantize_node.set_name(fake_quant_node.name());
    443         SetNodeAttr("T", DT_QUINT8, &dequantize_node);
    444         SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
    445         AddNodeInput(requantize_node.name() + ":0", &dequantize_node);
    446         AddNodeInput(requantize_node.name() + ":1", &dequantize_node);
    447         AddNodeInput(requantize_node.name() + ":2", &dequantize_node);
    448         new_nodes->push_back(dequantize_node);
    450         return Status::OK();
    451       },
    452       {}, output_graph_def));
    454   return Status::OK();
    455 }
    457 // We always generate Requantize ops driven by dynamic RequantizationRange
    458 // calculations when we produce quantized ops like Conv2D or BiasAdd with
    459 // 32-bit results. If there were FakeQuant ops already for those activation
    460 // layers, then there will be a later Requantize op with constant min/max
    461 // inputs, which is preferable for fast inference. This pass looks for those
    462 // later Requantize ops, and replaces the dynamic version with them.
    463 Status MergeAdjacentRequantizes(const GraphDef& input_graph_def,
    464                                 const TransformFuncContext& context,
    465                                 GraphDef* output_graph_def) {
    466   TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
    467       input_graph_def,  // clang-format off
    468       {"Requantize",
    469         {
    470           {"QuantizeV2",
    471             {
    472               {"Dequantize",
    473                 {
    474                   {"Requantize",
    475                     {
    476                       {"*"},
    477                       {"*"},
    478                       {"*"},
    479                       {"RequantizationRange"},
    480                       {"RequantizationRange"},
    481                     }
    482                   },
    483                   {"Requantize"},
    484                   {"Requantize"},
    485                 }
    486               },
    487               {"Const"},
    488               {"Const"},
    489             },
    490           },
    491           {"QuantizeV2"},
    492           {"QuantizeV2"},
    493           {"Const"},
    494           {"Const"},
    495         }
    496       },  // clang-format on
    497       [](const NodeMatch& match, const std::set<string>& input_nodes,
    498          const std::set<string>& output_nodes,
    499          std::vector<NodeDef>* new_nodes) {
    500         const NodeDef& fake_requantize_node = match.node;
    501         const NodeDef& original_op_node =
    502             match.inputs[0].inputs[0].inputs[0].inputs[0].node;
    503         const NodeDef& fake_requantize_min_node = match.inputs[3].node;
    504         const NodeDef& fake_requantize_max_node = match.inputs[4].node;
    506         new_nodes->push_back(original_op_node);
    507         new_nodes->push_back(fake_requantize_min_node);
    508         new_nodes->push_back(fake_requantize_max_node);
    510         NodeDef requantize_node;
    511         requantize_node = fake_requantize_node;
    512         requantize_node.mutable_input()->Clear();
    513         AddNodeInput(original_op_node.name() + ":0", &requantize_node);
    514         AddNodeInput(original_op_node.name() + ":1", &requantize_node);
    515         AddNodeInput(original_op_node.name() + ":2", &requantize_node);
    516         AddNodeInput(fake_requantize_min_node.name(), &requantize_node);
    517         AddNodeInput(fake_requantize_max_node.name(), &requantize_node);
    518         new_nodes->push_back(requantize_node);
    520         return Status::OK();
    521       },
    522       {}, output_graph_def));
    524   return Status::OK();
    525 }
    527 // Sometimes FakeQuantWithMinMaxVars ops are added at the end of a chain of
    528 // linear ops like Relu, MaxPool, etc, several steps from the Conv2D or BiasAdd
    529 // op that we want to apply the trained constant conversions to. This pass tries
    530 // to move FakeQuant ops up the input chain, so they're as close as possible to
    531 // the 32-bit conversion, and so can be easily merged into the automatic dynamic
    532 // Requantizes.
    533 Status HoistFakeQuants(const GraphDef& input_graph_def,
    534                        const TransformFuncContext& context,
    535                        GraphDef* output_graph_def) {
    536   GraphDef current_graph_def = input_graph_def;
    537   const int max_depth = 3;
    538   for (int depth = max_depth; depth > 0; --depth) {
    539     OpTypePattern pattern = {"*"};
    540     for (int i = 0; i < depth; ++i) {
    541       pattern = {"*", {pattern}};
    542     }
    543     pattern = {"FakeQuantWithMinMaxVars", {pattern, {"Const"}, {"Const"}}};
    544     GraphDef hoisted_graph_def;
    545     TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
    546         current_graph_def, pattern,
    547         [depth](const NodeMatch& match, const std::set<string>& input_nodes,
    548                 const std::set<string>& output_nodes,
    549                 std::vector<NodeDef>* new_nodes) {
    550           const NodeDef& fake_quant_node = match.node;
    551           const NodeDef& fake_quant_min_node = match.inputs[1].node;
    552           const NodeDef& fake_quant_max_node = match.inputs[2].node;
    553           std::vector<NodeDef> linear_nodes;
    554           NodeMatch current_match = match;
    555           for (int i = 0; i <= depth; ++i) {
    556             linear_nodes.push_back(current_match.inputs[0].node);
    557             current_match = current_match.inputs[0];
    558           }
    559           NodeDef new_fake_quant_node;
    560           new_fake_quant_node = fake_quant_node;
    561           new_fake_quant_node.set_name(fake_quant_node.name() + "_hoisted");
    562           new_fake_quant_node.set_input(
    563               0, linear_nodes[linear_nodes.size() - 2].input(0));
    564           new_nodes->push_back(new_fake_quant_node);
    566           new_nodes->push_back(fake_quant_min_node);
    567           new_nodes->push_back(fake_quant_max_node);
    569           linear_nodes[linear_nodes.size() - 2].set_input(
    570               0, new_fake_quant_node.name());
    571           linear_nodes.front().set_name(fake_quant_node.name());
    572           for (const NodeDef& linear_node : linear_nodes) {
    573             new_nodes->push_back(linear_node);
    574           }
    576           return Status::OK();
    577         },
    578         {}, &hoisted_graph_def));
    579     current_graph_def = hoisted_graph_def;
    580   }
    581   *output_graph_def = current_graph_def;
    583   return Status::OK();
    584 }
    586 // Converts any float ops that have eight-bit equivalents into their quantized
    587 // forms, so that as much calculation as possible is done in the lower-precision
    588 // format.
    589 Status QuantizeNodes(const GraphDef& input_graph_def,
    590                      const TransformFuncContext& context,
    591                      GraphDef* output_graph_def) {
    592   // Loop through all of the quantizable op types, and replace any occurrences
    593   // with equivalent sub-graphs with quantized ops at their core. For example
    594   // this one-input operation:
    595   //
    596   //            Input(float)
    597   //                |
    598   //                v
    599   //            Operation
    600   //                |
    601   //                v
    602   //             (float)
    603   //
    604   // Will be turned into it's quantized equivalent:
    605   //
    606   //      Input(float)          ReshapeDims
    607   //         +------v v-------------+
    608   //         |    Reshape
    609   //         |      |
    610   //         |      |          ReductionDims
    611   //         |      +-----+         |
    612   //         |      | +---c---------+
    613   //         |      v v   v v-------+
    614   //         |      Min   Max
    615   //         |  +----+      |
    616   //         v  v  v--------+
    617   //        Quantize
    618   //            |
    619   //            v
    620   //     QuantizedOperation
    621   //        |   |   |
    622   //        v   v   v
    623   //        Dequantize
    624   //            |
    625   //            v
    626   //         (float)
    627   //
    628   // This keeps the inputs and outputs visible to the rest of the graph in
    629   // float
    630   // and converts them down to quantized buffers internally for the
    631   // computation.
    632   // The result will end up with a lot of redundant dequantize/quantize pairs
    633   // between adjacent quantized ops, but a later pass removes these where it
    634   // can.
    636   std::set<string> ops_to_ignore;
    637   if (context.params.count("ignore_op") > 0) {
    638     for (const string& name : context.params.at("ignore_op")) {
    639       ops_to_ignore.insert(name);
    640     }
    641   }
    643   const std::vector<QuantizedOpInfo>& op_list = GetQuantizedOpList();
    644   string op_pattern;
    645   bool is_first = true;
    646   std::map<string, QuantizedOpInfo> op_map;
    647   for (const QuantizedOpInfo& op_info : op_list) {
    648     if (ops_to_ignore.count(op_info.float_name) == 0) {
    649       strings::StrAppend(&op_pattern, (is_first ? "" : "|"),
    650                          op_info.float_name);
    651       op_map.insert({op_info.float_name, op_info});
    652       is_first = false;
    653     }
    654   }
    656   // If input_min and input max have been passed in, then we convert all float
    657   // Placeholder nodes into quantized versions, with the supplied values as
    658   // their range.
    659   GraphDef placeholder_graph_def;
    661       QuantizePlaceholders(input_graph_def, context, &placeholder_graph_def));
    662   TF_RETURN_IF_ERROR(IsGraphValid(placeholder_graph_def));
    664   // If there are any FakeQuantWithMinMaxVars at the end of a chain of linear
    665   // operations like Relu or MaxPool, move them up so that they're as close as
    666   // possible to ops with 32-bit outputs like BiasAdd or Conv2D.
    667   GraphDef hoisted_graph_def;
    669       HoistFakeQuants(placeholder_graph_def, context, &hoisted_graph_def));
    670   TF_RETURN_IF_ERROR(IsGraphValid(hoisted_graph_def));
    672   // Convert any FakeQuantWithMinMaxVars, which hold the trained ranges of
    673   // activation layers, into Requantize ops with those ranges instead. This
    674   // makes it easier to replace the dynamic range calculations that are used
    675   // by default.
    676   GraphDef converted_graph_def;
    677   TF_RETURN_IF_ERROR(ConvertFakeQuantsToRequantize(hoisted_graph_def, context,
    678                                                    &converted_graph_def));
    679   TF_RETURN_IF_ERROR(IsGraphValid(converted_graph_def));
    681   // If fallback_min and fallback_max are set, then we'll use hardwired ranges
    682   // for all the 32-bit to 8-bit requantizations.
    683   float fallback_min;
    684   float fallback_max;
    685   bool has_fallback_range;
    686   TF_RETURN_IF_ERROR(ExtractRangeFromParams(
    687       context, "fallback_min", "fallback_max", &fallback_min, &fallback_max,
    688       &has_fallback_range));
    690   // Replace all occurrences of the current float op with its quantized
    691   // equivalent.
    692   GraphDef quantized_graph_def;
    693   TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
    694       converted_graph_def, {op_pattern},
    695       [&op_map, fallback_min, fallback_max, has_fallback_range](
    696           const NodeMatch& match, const std::set<string>& input_nodes,
    697           const std::set<string>& output_nodes,
    698           std::vector<NodeDef>* new_nodes) {
    699         const NodeDef& float_node = match.node;
    700         const QuantizedOpInfo& op_info = op_map[float_node.op()];
    702         DataTypeVector input_types;
    703         DataTypeVector output_types;
    704         TF_RETURN_IF_ERROR(
    705             GetInOutTypes(float_node, &input_types, &output_types));
    706         bool are_all_float = true;
    707         for (int i = 0; i < float_node.input_size(); ++i) {
    708           // Skip any known non-float inputs.
    709           if (op_info.unquantized_inputs.count(i)) {
    710             continue;
    711           }
    712           if (input_types[i] != DT_FLOAT) {
    713             are_all_float = false;
    714           }
    715         }
    716         for (const DataType& output_type : output_types) {
    717           if (output_type != DT_FLOAT) {
    718             are_all_float = false;
    719           }
    720         }
    721         // This isn't a float op, so don't quantize it.
    722         if (!are_all_float) {
    723           CopyOriginalMatch(match, new_nodes);
    724           return Status::OK();
    725         }
    727         string namespace_prefix = float_node.name() + "_eightbit";
    729         // Quantize all of the inputs.
    730         std::vector<string> quantized_input_names;
    731         for (int i = 0; i < float_node.input_size(); ++i) {
    732           // Skip any non-float inputs.
    733           if (op_info.unquantized_inputs.count(i)) {
    734             continue;
    735           }
    737           const string& input_name = float_node.input(i);
    738           string unique_input_name =
    739               namespace_prefix + "/" + UniqueNodeNameFromInput(input_name);
    741           // Add some common constants we need for reshaping inputs.
    742           NodeDef reshape_dims;
    743           reshape_dims.set_op("Const");
    744           reshape_dims.set_name(unique_input_name + "/reshape_dims");
    745           AddNodeInput("^" + NodeNameFromInput(input_name), &reshape_dims);
    746           SetNodeAttr("dtype", DT_INT32, &reshape_dims);
    747           Tensor reshape_dims_tensor(DT_INT32, {1});
    748           reshape_dims_tensor.flat<int32>()(0) = -1;
    749           SetNodeTensorAttr<int32>("value", reshape_dims_tensor, &reshape_dims);
    750           new_nodes->push_back(reshape_dims);
    752           NodeDef reduction_dims;
    753           reduction_dims.set_op("Const");
    754           reduction_dims.set_name(unique_input_name + "/reduction_dims");
    755           AddNodeInput("^" + NodeNameFromInput(input_name), &reduction_dims);
    756           SetNodeAttr("dtype", DT_INT32, &reduction_dims);
    757           Tensor reduction_dims_tensor(DT_INT32, {1});
    758           reduction_dims_tensor.flat<int32>()(0) = 0;
    759           SetNodeTensorAttr<int32>("value", reduction_dims_tensor,
    760                                    &reduction_dims);
    761           new_nodes->push_back(reduction_dims);
    763           NodeDef reshape_node;
    764           reshape_node.set_op("Reshape");
    765           reshape_node.set_name(unique_input_name + "/reshape");
    766           SetNodeAttr("T", DT_FLOAT, &reshape_node);
    767           AddNodeInput(input_name, &reshape_node);
    768           AddNodeInput(reshape_dims.name(), &reshape_node);
    769           new_nodes->push_back(reshape_node);
    771           NodeDef min_node;
    772           min_node.set_op("Min");
    773           min_node.set_name(unique_input_name + "/min");
    774           SetNodeAttr("T", DT_FLOAT, &min_node);
    775           SetNodeAttr("keep_dims", false, &min_node);
    776           AddNodeInput(reshape_node.name(), &min_node);
    777           AddNodeInput(reduction_dims.name(), &min_node);
    778           new_nodes->push_back(min_node);
    780           NodeDef max_node;
    781           max_node.set_op("Max");
    782           max_node.set_name(unique_input_name + "/max");
    783           SetNodeAttr("T", DT_FLOAT, &max_node);
    784           SetNodeAttr("keep_dims", false, &max_node);
    785           AddNodeInput(reshape_node.name(), &max_node);
    786           AddNodeInput(reduction_dims.name(), &max_node);
    787           new_nodes->push_back(max_node);
    789           NodeDef quantize_node;
    790           quantize_node.set_op("QuantizeV2");
    791           quantize_node.set_name(unique_input_name + "/quantize");
    792           SetNodeAttr("T", DT_QUINT8, &quantize_node);
    793           SetNodeAttr("mode", "MIN_FIRST", &quantize_node);
    794           AddNodeInput(input_name, &quantize_node);
    795           AddNodeInput(min_node.name(), &quantize_node);
    796           AddNodeInput(max_node.name(), &quantize_node);
    797           new_nodes->push_back(quantize_node);
    798           quantized_input_names.push_back(quantize_node.name());
    799         }
    801         // Set up the quantized version of the current op.
    802         NodeDef quantized_main_node;
    803         quantized_main_node.set_op("Quantized" + float_node.op());
    804         quantized_main_node.set_name(float_node.name() + "/eightbit");
    805         for (const string& attr_to_copy : op_info.attrs_to_copy) {
    806           CopyNodeAttr(float_node, attr_to_copy, attr_to_copy,
    807                        &quantized_main_node);
    808         }
    809         for (const std::pair<string, DataType>& dtype_to_set :
    810              op_info.dtypes_to_set) {
    811           SetNodeAttr(dtype_to_set.first, dtype_to_set.second,
    812                       &quantized_main_node);
    813         }
    814         int quantized_input_index = 0;
    815         for (int i = 0; i < float_node.input_size(); ++i) {
    816           if (op_info.unquantized_inputs.count(i)) {
    817             AddNodeInput(float_node.input(i), &quantized_main_node);
    818           } else {
    819             const string& quantized_input_name =
    820                 quantized_input_names[quantized_input_index];
    821             AddNodeInput(quantized_input_name + ":0", &quantized_main_node);
    822             ++quantized_input_index;
    823           }
    824         }
    825         if (op_info.min_max_order == QuantizedOpInfo::CONTIGUOUS_MIN_MAX) {
    826           for (const string& quantized_input_name : quantized_input_names) {
    827             AddNodeInput(quantized_input_name + ":1", &quantized_main_node);
    828             AddNodeInput(quantized_input_name + ":2", &quantized_main_node);
    829           }
    830         } else {
    831           for (const string& quantized_input_name : quantized_input_names) {
    832             AddNodeInput(quantized_input_name + ":1", &quantized_main_node);
    833           }
    834           for (const string& quantized_input_name : quantized_input_names) {
    835             AddNodeInput(quantized_input_name + ":2", &quantized_main_node);
    836           }
    837         }
    838         new_nodes->push_back(quantized_main_node);
    840         string eight_bit_node_name;
    841         if (op_info.output_bit_depth == DT_QINT32) {
    842           // Shrink the range of the output down from 32 bits to 8.
    843           string requantize_min_input;
    844           string requantize_max_input;
    845           if (has_fallback_range) {
    846             // Use constant values for the min/max range if they were given.
    847             NodeDef fallback_min_node;
    848             fallback_min_node.set_op("Const");
    849             fallback_min_node.set_name(quantized_main_node.name() +
    850                                        "/fallback_min");
    851             SetNodeAttr("dtype", DT_FLOAT, &fallback_min_node);
    852             Tensor fallback_min_tensor(DT_FLOAT, {});
    853             fallback_min_tensor.flat<float>()(0) = fallback_min;
    854             SetNodeTensorAttr<float>("value", fallback_min_tensor,
    855                                      &fallback_min_node);
    856             new_nodes->push_back(fallback_min_node);
    858             NodeDef fallback_max_node;
    859             fallback_max_node.set_op("Const");
    860             fallback_max_node.set_name(quantized_main_node.name() +
    861                                        "/fallback_max");
    862             SetNodeAttr("dtype", DT_FLOAT, &fallback_max_node);
    863             Tensor fallback_max_tensor(DT_FLOAT, {});
    864             fallback_max_tensor.flat<float>()(0) = fallback_max;
    865             SetNodeTensorAttr<float>("value", fallback_max_tensor,
    866                                      &fallback_max_node);
    867             new_nodes->push_back(fallback_max_node);
    869             requantize_min_input = fallback_min_node.name();
    870             requantize_max_input = fallback_max_node.name();
    871           } else {
    872             // Otherwise dynamically measure the range each time.
    873             NodeDef requant_range_node;
    874             requant_range_node.set_op("RequantizationRange");
    875             requant_range_node.set_name(quantized_main_node.name() +
    876                                         "/requant_range");
    877             SetNodeAttr("Tinput", DT_QINT32, &requant_range_node);
    878             AddNodeInput(quantized_main_node.name() + ":0",
    879                          &requant_range_node);
    880             AddNodeInput(quantized_main_node.name() + ":1",
    881                          &requant_range_node);
    882             AddNodeInput(quantized_main_node.name() + ":2",
    883                          &requant_range_node);
    884             new_nodes->push_back(requant_range_node);
    886             requantize_min_input = requant_range_node.name() + ":0";
    887             requantize_max_input = requant_range_node.name() + ":1";
    888           }
    889           NodeDef requantize_node;
    890           requantize_node.set_op("Requantize");
    891           requantize_node.set_name(quantized_main_node.name() + "/requantize");
    892           SetNodeAttr("Tinput", DT_QINT32, &requantize_node);
    893           SetNodeAttr("out_type", DT_QUINT8, &requantize_node);
    894           AddNodeInput(quantized_main_node.name() + ":0", &requantize_node);
    895           AddNodeInput(quantized_main_node.name() + ":1", &requantize_node);
    896           AddNodeInput(quantized_main_node.name() + ":2", &requantize_node);
    897           AddNodeInput(requantize_min_input, &requantize_node);
    898           AddNodeInput(requantize_max_input, &requantize_node);
    899           new_nodes->push_back(requantize_node);
    900           eight_bit_node_name = requantize_node.name();
    901         } else {
    902           eight_bit_node_name = quantized_main_node.name();
    903         }
    905         // Convert the 8-bit result back into float for the final output.
    906         NodeDef dequantize_node;
    907         dequantize_node.set_op("Dequantize");
    908         dequantize_node.set_name(float_node.name());
    909         SetNodeAttr("T", DT_QUINT8, &dequantize_node);
    910         SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
    911         AddNodeInput(eight_bit_node_name + ":0", &dequantize_node);
    912         AddNodeInput(eight_bit_node_name + ":1", &dequantize_node);
    913         AddNodeInput(eight_bit_node_name + ":2", &dequantize_node);
    914         new_nodes->push_back(dequantize_node);
    916         return Status::OK();
    917       },
    918       {}, &quantized_graph_def));
    919   TF_RETURN_IF_ERROR(IsGraphValid(quantized_graph_def));
    921   // If we've ended up with two Requantize ops in a row (for example if there
    922   // was a Conv2D feeding into a FakeQuantWithMinMaxVars) merge them together,
    923   // using the trained range from the second op.
    924   GraphDef merged_graph_def;
    925   TF_RETURN_IF_ERROR(MergeAdjacentRequantizes(quantized_graph_def, context,
    926                                               &merged_graph_def));
    927   TF_RETURN_IF_ERROR(IsGraphValid(merged_graph_def));
    929   // There can be duplicate quantize nodes if multiple ops pull from a single
    930   // input, which makes it harder to remove redundant ones, so strip them out.
    931   GraphDef deduped_graph_def;
    933       MergeDuplicateNodes(merged_graph_def, context, &deduped_graph_def));
    934   TF_RETURN_IF_ERROR(IsGraphValid(deduped_graph_def));
    936   // Look for Dequantizes that immediately go into Quantizes, and remove them
    937   // since the two together cancel each other out. This allows us to keep the
    938   // data flow in eight bit where two adjacent ops are in eight bit, but still
    939   // keep interoperability with float ops.
    940   TF_RETURN_IF_ERROR(RemoveRedundantQuantizations(deduped_graph_def, context,
    941                                                   output_graph_def));
    942   TF_RETURN_IF_ERROR(IsGraphValid(*output_graph_def));
    944   return Status::OK();
    945 }
    // Presumably registers QuantizeNodes under the name "quantize_nodes" with the
    // graph-transform registry (macro is defined outside this file — confirm in
    // transform_utils.h). QuantizeNodes rewrites matching float ops into their
    // "Quantized" + op equivalents, wrapped in QuantizeV2 / Dequantize nodes.
    947 REGISTER_GRAPH_TRANSFORM("quantize_nodes", QuantizeNodes);
    // MergeDuplicateNodes already runs as one pass inside QuantizeNodes; this
    // registration also exposes it as a standalone "merge_duplicate_nodes"
    // transform.
    949 REGISTER_GRAPH_TRANSFORM("merge_duplicate_nodes", MergeDuplicateNodes);
    951 }  // namespace graph_transforms
    952 }  // namespace tensorflow