/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <atomic>
#include <set>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/graph/quantize_training.h"

#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace {

// TODO(suharshs): If desired, make these values configurable.
const uint32 kAllowedInputs = 2;
const float kEMADecay = 0.999f;

// Node types to rewrite. Insert quantize_and_dequantize op for their inputs.
const auto* nodes_to_rewrite =
    new std::unordered_set<string, StringPieceHasher>{"MatMul", "Conv2D"};

// Contains necessary parameters to convert an edge.
struct EdgeToConvert {
  // edge is not owned here.
  const Edge* edge;
  int32 num_bits;
  bool signed_input;
  bool range_given;
  float input_min;
  float input_max;

  EdgeToConvert(const Edge* e, int32 bits, bool sign, bool range, float min,
                float max)
      : edge(e),
        num_bits(bits),
        signed_input(sign),
        range_given(range),
        input_min(min),
        input_max(max) {}
};

// Decides whether a node is in the backward pass by checking whether its name
// starts with "gradients".
// TODO(jmchen): Make this check more robust, as it is not guaranteed that a
// forward node will never be named with a leading "gradients".
inline bool IsGradientNode(const Graph* graph, const Node* node) {
  static const string tag = "gradients";
  return (node->name().compare(0, tag.size(), tag) == 0);
}

// Finds the type of the input to set the parameters for the
// quantize_and_dequantize op.
// Returns true if the root tensor op type is known, false otherwise.
bool FindType(const Graph* graph, const Node* node, bool* signed_input,
              bool* range_given, float* input_min, float* input_max) {
  const string& src_op = node->type_string();
  if (src_op == "Const" || src_op == "Variable" || src_op == "VariableV2") {
    *signed_input = true;
    *range_given = false;
  } else if (src_op == "Relu") {
    // Range is not given for Relu.
    *signed_input = false;
    *range_given = false;
  } else if (src_op == "Relu6") {
    // TODO(suharshs): The theoretical min and max are 0 and 6; if the actual
    // activations fall within a narrower range, we could quantize even
    // further. The same holds for other bounded activations such as Sigmoid.
    *signed_input = false;
    *range_given = true;
    *input_min = 0;
    *input_max = 6;
  } else if (src_op == "Sigmoid") {
    *signed_input = false;
    *range_given = true;
    *input_min = 0;
    *input_max = 1;
  } else if (src_op == "Tanh") {
    *signed_input = true;
    *range_given = true;
    *input_min = -1;
    *input_max = 1;
  } else if (src_op == "Reshape" || src_op == "ConcatV2") {
    // Reshape has 2 inputs and the first one is the tensor.
    // ConcatV2 has many inputs, but they should all have the same activation
    // function (e.g. in Inception), so we just recurse on the first input.
    for (const Edge* edge : node->in_edges()) {
      if (edge->src_output() != Graph::kControlSlot && edge->dst_input() == 0) {
        FindType(graph, edge->src(), signed_input, range_given, input_min,
                 input_max);
      }
    }
  } else if (src_op == "Identity" || src_op == "MaxPool" ||
             src_op == "AvgPool" || src_op == "MaxPool3D" ||
             src_op == "AvgPool3D") {
    // All these ops have exactly one data input.
    for (const Edge* edge : node->in_edges()) {
      if (edge->src_output() != Graph::kControlSlot) {
        FindType(graph, edge->src(), signed_input, range_given, input_min,
                 input_max);
      }
    }
  } else {
    // Unknown type; this could be a model input (e.g. the input examples).
    // TODO(jmchen): Set the params for inputs based on a user-provided hint.
    *signed_input = true;
    *range_given = false;
    return false;
  }

  return true;
}

// Finds the Save op and its inputs.
Status FindSaveOp(const Graph* graph, Node** save_op,
                  std::vector<const Edge*>* in_edges, bool* found) {
  *found = false;
  for (Node* node : graph->op_nodes()) {
    if (node->type_string() == "SaveV2") {
      // We found multiple save ops.
      if (*found) {
        return errors::InvalidArgument("Input graph has multiple SaveV2 ops.");
      }
      *save_op = node;
      *found = true;
      TF_RETURN_IF_ERROR(node->input_edges(in_edges));
    }
  }
  return Status::OK();
}

Node* FindRestoreAllOp(const Graph* graph, StringPiece save_prefix) {
  for (Node* node : graph->op_nodes()) {
    // The restore_all op should have the same prefix as the save_op.
    if (node->name() == strings::StrCat(save_prefix, "/restore_all")) {
      return node;
    }
  }
  return nullptr;
}

// Strips the last "/suffix" from a name.
// We use this to construct the names of restore ops in the same way they are
// constructed by the Saver.
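// For example, a node named "save/SaveV2" yields the prefix "save".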
StringPiece GetNodeNamePrefix(const Node* node) {
  StringPiece name = node->name();
  return name.substr(0, name.rfind('/'));
}

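// Copies the elements of src into dst. Assumes dst was allocated with at
// least src.NumElements() elements.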
void FillStringTensor(Tensor* dst, const Tensor& src) {
  auto dst_flat = dst->flat<string>();
  auto src_flat = src.flat<string>();
  for (int i = 0; i < src.NumElements(); i++) {
    dst_flat(i) = src_flat(i);
  }
}

// Adds the added variables as inputs to the Save op.
// We change the inputs of the SaveV2 op to include the names of the added
// variables, and we add the variables themselves as inputs to the save op.
Status ConnectVariablesToSaveOp(Graph* graph, Node* save_op,
                                const std::vector<const Edge*>& in_edges,
                                const std::vector<Node*>& added_variables) {
  Node* tensor_names_op = in_edges[1]->src();
  Node* shape_and_slices_op = in_edges[2]->src();

  // Get the tensor_names and shape_and_slices tensors from the const op.
  Tensor tensor_names;
  Tensor shape_and_slices;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(tensor_names_op->attrs(), "value", &tensor_names));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(shape_and_slices_op->attrs(), "value", &shape_and_slices));

  int tn_size = tensor_names.NumElements();
  int var_size = added_variables.size();

  // Create a new save_op that has inputs to all the new variables.
  NodeBuilder save_op_builder =
      NodeBuilder(save_op->name(), save_op->type_string());
  // The first three inputs are prefix, tensor_names, and shape_and_slices.
  for (int i = 0; i < 3; i++) {
    save_op_builder = save_op_builder.Input(in_edges[i]->src());
  }
  std::vector<NodeBuilder::NodeOut> var_nodeouts;
  var_nodeouts.reserve(tn_size + var_size);
  // The rest of the inputs are used to construct the tensor list argument.
  for (int i = 3; i < in_edges.size(); i++) {
    var_nodeouts.emplace_back(in_edges[i]->src());
  }

  // Add the new values to the tensors and the op input.
  Tensor new_tensor_names(DT_STRING, TensorShape({tn_size + var_size}));
  Tensor new_shape_and_slices(DT_STRING, TensorShape({tn_size + var_size}));
  FillStringTensor(&new_tensor_names, tensor_names);
  FillStringTensor(&new_shape_and_slices, shape_and_slices);
  for (int i = 0; i < var_size; i++) {
    Node* var = added_variables[i];
    new_tensor_names.flat<string>()(tn_size + i) = var->name();
    new_shape_and_slices.flat<string>()(tn_size + i) = "";
    var_nodeouts.emplace_back(var);
  }
  save_op_builder = save_op_builder.Input(var_nodeouts);

  // Update the attrs.
  tensor_names_op->AddAttr("value", new_tensor_names);
  shape_and_slices_op->AddAttr("value", new_shape_and_slices);

  // Remove the old save_op and add the new one.
  Node* new_save_op;
  TF_RETURN_IF_ERROR(save_op_builder.Finalize(graph, &new_save_op));
  // Add outputs to the new_save_op; all outputs are control edges.
  for (const Edge* edge : save_op->out_edges()) {
    graph->AddControlEdge(new_save_op, edge->dst());
  }
  graph->RemoveNode(save_op);

  return Status::OK();
}

// Add a restore subgraph for each variable and connect to the restore_all op.
// For each variable we add the following subgraph:
//           Assign----restore_all
//          |      |
//   RestoreV2    Variable
Status AddRestoreVariableSubgraphs(Graph* graph, Node* save_op,
                                   const std::vector<const Edge*>& in_edges,
                                   const std::vector<Node*>& variables) {
  Node* prefix_op = in_edges[0]->src();
  StringPiece name_prefix = GetNodeNamePrefix(save_op);
  Node* restore_all = FindRestoreAllOp(graph, name_prefix);
  if (restore_all == nullptr) {
    return errors::InvalidArgument("graph has SaveOp, but no restore_all NoOp");
  }
  const string restore_op_name = strings::StrCat(name_prefix, "/RestoreV2");
  const string assign_op_name = strings::StrCat(name_prefix, "/Assign");
  for (Node* var : variables) {
    string new_restore_op_name = graph->NewName(restore_op_name);
    string new_assign_op_name = graph->NewName(assign_op_name);
    string tensor_names_op_name =
        strings::StrCat(new_restore_op_name, "/tensor_names");
    string shape_and_slices_op_name =
        strings::StrCat(new_restore_op_name, "/shape_and_slices");

    // Construct the tensor_names input with the variable name.
    Node* tensor_names;
    Tensor tensor_names_val(DT_STRING, TensorShape({1}));
    tensor_names_val.flat<string>()(0) = var->name();
    TF_RETURN_IF_ERROR(NodeBuilder(tensor_names_op_name, "Const")
                           .Attr("dtype", DT_STRING)
                           .Attr("value", tensor_names_val)
                           .Finalize(graph, &tensor_names));

    // Construct the shape_and_slices input with an empty string.
    Node* shape_and_slices;
    Tensor shape_and_slices_val(DT_STRING, TensorShape({1}));
    shape_and_slices_val.flat<string>()(0) = "";
    TF_RETURN_IF_ERROR(NodeBuilder(shape_and_slices_op_name, "Const")
                           .Attr("dtype", DT_STRING)
                           .Attr("value", shape_and_slices_val)
                           .Finalize(graph, &shape_and_slices));

    // Build the new Restore op for this variable.
    Node* restore_op;
    TF_RETURN_IF_ERROR(NodeBuilder(new_restore_op_name, "RestoreV2")
                           .Input(prefix_op)
                           .Input(tensor_names)
                           .Input(shape_and_slices)
                           .Attr("dtypes", {DT_FLOAT})
                           .Finalize(graph, &restore_op));

    // Create an Assign op, attaching the variable and Restore op to it.
    Node* assign_op;
    TF_RETURN_IF_ERROR(NodeBuilder(new_assign_op_name, "Assign")
                           .Input(var)
                           .Input(restore_op)
                           .Finalize(graph, &assign_op));

    // Add a control edge from the assign op to the restore_all op.
    graph->AddControlEdge(assign_op, restore_all);
  }
  return Status::OK();
}

// Adds new variables to save and restore ops matching the Save and Restore
// graphs created in tensorflow/python/training/saver.py.
Status AddSaveAndRestore(Graph* graph, const std::vector<Node*>& variables) {
  Node* save_op = nullptr;
  std::vector<const Edge*> in_edges;
  bool found = false;
  TF_RETURN_IF_ERROR(FindSaveOp(graph, &save_op, &in_edges, &found));
  if (found) {
    TF_RETURN_IF_ERROR(
        AddRestoreVariableSubgraphs(graph, save_op, in_edges, variables));
    TF_RETURN_IF_ERROR(
        ConnectVariablesToSaveOp(graph, save_op, in_edges, variables));
  }
  return Status::OK();
}

// Sets *output to a Node that computes the reduction axes covering all
// dimensions of input, i.e. the range [0, rank(input)).
Status MakeReductionAxes(Graph* graph, string name_prefix, Node* input,
                         Node** output) {
  name_prefix = strings::StrCat(name_prefix, "/ReductionAxes");
  Node* start;
  Tensor zero_tensor(DT_INT32, TensorShape());
  zero_tensor.flat<int32>()(0) = 0;
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/RangeStart"), "Const")
          .Attr("dtype", DT_INT32)
          .Attr("value", zero_tensor)
          .Finalize(graph, &start));
  Node* delta;
  Tensor one_tensor(DT_INT32, TensorShape());
  one_tensor.flat<int32>()(0) = 1;
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/RangeDelta"), "Const")
          .Attr("dtype", DT_INT32)
          .Attr("value", one_tensor)
          .Finalize(graph, &delta));
  Node* rank;
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/InputRank"), "Rank")
          .Input(input)
          .Finalize(graph, &rank));
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/ReductionAxes"), "Range")
          .Input(start)
          .Input(rank)
          .Input(delta)
          .Finalize(graph, output));
  return Status::OK();
}

// Computes the exponential moving average of input, updated in update_variable.
Status MakeExponentialMovingAverage(Graph* graph, string name_prefix,
                                    const NodeBuilder::NodeOut& input,
                                    Node* decay, Node* update_variable,
                                    Node** assign_value) {
  // variable_t+1 = variable_t - [(variable_t - value) * (1 - decay)]
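  // This is algebraically identical to the usual EMA update
  //   variable_t+1 = decay * variable_t + (1 - decay) * value,
  // but written so it maps directly onto the Sub/Mul nodes built below.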
  name_prefix = strings::StrCat(name_prefix, "/EMA");
  Node* one;
  Tensor one_tensor(DT_FLOAT, TensorShape());
  one_tensor.flat<float>()(0) = 1.0;
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/OneConst"), "Const")
          .Attr("dtype", DT_FLOAT)
          .Attr("value", one_tensor)
          .Finalize(graph, &one));
  Node* decay_complement;
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/DecayComplement"), "Sub")
          .Input(one)
          .Input(decay)
          .Finalize(graph, &decay_complement));

  Node* value_diff;
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/ValueDiff"), "Sub")
          .Input(update_variable)
          .Input(input)
          .Finalize(graph, &value_diff));
  Node* update_value;
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/UpdateValue"), "Mul")
          .Input(value_diff)
          .Input(decay_complement)
          .Finalize(graph, &update_value));

  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/EMAValue"), "Sub")
          .Input(update_variable)
          .Input(update_value)
          .Finalize(graph, assign_value));
  return Status::OK();
}

// Creates an automatically initialized exponential moving average variable.
// This uses a switch op to assign a value to the variable on the first run,
// and update with the moving average for all other runs:
//                   init_val
//                      |
//      var--is_init--switch
//       |      true /      \ false
//       |          |        |
//       |         EMA    init_val
//       |           \      /
//       +----------- assign
Status MakeInitializedEMAVariable(Graph* graph, const string& name, Node* decay,
                                  Node* init_val,
                                  std::vector<Node*>* added_variables,
                                  Node** var) {
  // TODO(suharshs): Update this to use ResourceVariables when they are ready.
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name, "/Variable"), "VariableV2")
          .Attr("shape", TensorShape())
          .Attr("dtype", DT_FLOAT)
          .Finalize(graph, var));
  added_variables->push_back(*var);

  Node* is_initialized;
  TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/IsInitialized"),
                                 "IsVariableInitialized")
                         .Input(*var)
                         .Finalize(graph, &is_initialized));
  Node* switch_node;
  TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/Switch"), "Switch")
                         .Input(init_val)
                         .Input(is_initialized)
                         .Finalize(graph, &switch_node));
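  // Switch forwards init_val to output port 0 when is_initialized is false
  // and to output port 1 when it is true.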
  NodeBuilder::NodeOut output_false = NodeBuilder::NodeOut(switch_node, 0);
  NodeBuilder::NodeOut output_true = NodeBuilder::NodeOut(switch_node, 1);

  Node* ema_value;
  TF_RETURN_IF_ERROR(MakeExponentialMovingAverage(graph, name, output_true,
                                                  decay, *var, &ema_value));

  Node* assign_value;
  TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/Merge"), "Merge")
                         .Input({output_false, ema_value})
                         .Finalize(graph, &assign_value));

  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name, "/AssignValue"), "Assign")
          .Input(*var)
          .Input(assign_value)
          .Finalize(graph, var));
  return Status::OK();
}

// Computes the min and max EMA of input and stores them in min_var and max_var.
Status MakeEMAMinMaxVars(Graph* graph, const string& name_prefix, Node* input,
                         std::vector<Node*>* added_variables, Node** min_var,
                         Node** max_var) {
  // TODO(suharshs): The decay will be constant, so we could make only one
  // for all quantize_and_dequantize ops to share; it would have to live
  // outside this function.
  Tensor decay_tensor(DT_FLOAT, TensorShape());
  decay_tensor.flat<float>()(0) = kEMADecay;
  Node* decay;
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/Decay"), "Const")
          .Attr("dtype", DT_FLOAT)
          .Attr("value", decay_tensor)
          .Finalize(graph, &decay));

  Node* reduction_axes;
  TF_RETURN_IF_ERROR(
      MakeReductionAxes(graph, name_prefix, input, &reduction_axes));
  Node* min;
  string min_name = strings::StrCat(name_prefix, "/Min");
  TF_RETURN_IF_ERROR(NodeBuilder(min_name, "Min")
                         .Input(input)
                         .Input(reduction_axes)
                         .Finalize(graph, &min));
  Node* max;
  string max_name = strings::StrCat(name_prefix, "/Max");
  TF_RETURN_IF_ERROR(NodeBuilder(max_name, "Max")
                         .Input(input)
                         .Input(reduction_axes)
                         .Finalize(graph, &max));
  TF_RETURN_IF_ERROR(MakeInitializedEMAVariable(graph, min_name, decay, min,
                                                added_variables, min_var));
  TF_RETURN_IF_ERROR(MakeInitializedEMAVariable(graph, max_name, decay, max,
                                                added_variables, max_var));
  return Status::OK();
}

// Makes an input min and max constant if the range is given. Otherwise, makes
// min and max variables that are updated by an EMA.
Status MakeInputMinMax(Graph* graph, const string& name_prefix,
                       const EdgeToConvert& edge,
                       std::vector<Node*>* added_variables, Node** input_min,
                       Node** input_max) {
  if (edge.range_given) {
    // Make constant nodes for the input_min and input_max if the range is
    // provided.
    Tensor input_min_tensor(DT_FLOAT, TensorShape());
    input_min_tensor.flat<float>()(0) = edge.input_min;
    TF_RETURN_IF_ERROR(
        NodeBuilder(strings::StrCat(name_prefix, "/InputMin"), "Const")
            .Attr("dtype", DT_FLOAT)
            .Attr("value", input_min_tensor)
            .Finalize(graph, input_min));
    Tensor input_max_tensor(DT_FLOAT, TensorShape());
    input_max_tensor.flat<float>()(0) = edge.input_max;
    TF_RETURN_IF_ERROR(
        NodeBuilder(strings::StrCat(name_prefix, "/InputMax"), "Const")
            .Attr("dtype", DT_FLOAT)
            .Attr("value", input_max_tensor)
            .Finalize(graph, input_max));
  } else {
    // If the range is not given, estimate the range with EMA variables.
    TF_RETURN_IF_ERROR(MakeEMAMinMaxVars(graph, name_prefix, edge.edge->src(),
                                         added_variables, input_min,
                                         input_max));
  }

  return Status::OK();
}

// Adds a QuantizeAndDequantizeV2 or FakeQuantWithMinMaxVars op
// (and required input nodes) based on edge.
// The result is stored in convert_node.
Status MakeQuantizeOp(Graph* graph, const string& name_prefix,
                      const string& quant_op_type, const EdgeToConvert& edge,
                      std::vector<Node*>* added_variables,
                      Node** convert_node) {
  Node* input_min;
  Node* input_max;
  TF_RETURN_IF_ERROR(MakeInputMinMax(graph, name_prefix, edge, added_variables,
                                     &input_min, &input_max));
  string quant_name = strings::StrCat(name_prefix, "/", quant_op_type);
  if (quant_op_type == "QuantizeAndDequantizeV2") {
    TF_RETURN_IF_ERROR(NodeBuilder(quant_name, quant_op_type)
                           .Input(edge.edge->src())
                           .Input(input_min)
                           .Input(input_max)
                           .Attr("signed_input", edge.signed_input)
                           .Attr("num_bits", edge.num_bits)
                           .Attr("range_given", true)
                           .Finalize(graph, convert_node));
  } else if (quant_op_type == "FakeQuantWithMinMaxVars") {
    TF_RETURN_IF_ERROR(NodeBuilder(quant_name, quant_op_type)
                           .Input(edge.edge->src())
                           .Input(input_min)
                           .Input(input_max)
                           .Attr("num_bits", edge.num_bits)
                           .Finalize(graph, convert_node));
  } else {
    return errors::InvalidArgument("Unknown quant op type: ", quant_op_type);
  }
  return Status::OK();
}

// Inserts a conversion op for each target edge, connects it to the graph, and
// removes the old edge.
Status ProcessTargetEdges(Graph* graph, const string& quant_op_type,
                          const std::vector<EdgeToConvert>& target_edges) {
  // Remember previously converted ops to avoid duplicate conversion of the
  // same input.
  std::unordered_map<string, Node*, StringPieceHasher> name_index;
  std::vector<Node*> added_variables;
  for (const EdgeToConvert& edge : target_edges) {
    Node* convert_node;
    string name_prefix = edge.edge->src()->name();

    auto iter = name_index.find(name_prefix);
    if (iter == name_index.end()) {
      TF_RETURN_IF_ERROR(MakeQuantizeOp(graph, name_prefix, quant_op_type, edge,
                                        &added_variables, &convert_node));
      name_index[name_prefix] = convert_node;
    } else {
      convert_node = iter->second;
    }

    graph->AddEdge(convert_node, 0, edge.edge->dst(), edge.edge->dst_input());
    graph->RemoveEdge(edge.edge);
  }

  TF_RETURN_IF_ERROR(AddSaveAndRestore(graph, added_variables));

  return Status::OK();
}

}  // namespace

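// Rewrites graph in place: for each input of every matching node (MatMul and
// Conv2D, outside the backward pass), inserts a quantization op of type
// quant_op_type, and hooks the EMA variables this creates into any existing
// Save/Restore subgraphs.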
Status DoQuantizeTraining(int32 num_bits, const string& quant_op_type,
                          Graph* graph) {
  if (graph == nullptr) {
    return errors::InvalidArgument("Cannot accept empty graph pointer.");
  }

  if (num_bits < 1 || num_bits > 63) {
    return errors::OutOfRange("num_bits should be in range [1, 63] but is: ",
                              num_bits);
  }
  int potential_input = 0;
  std::vector<EdgeToConvert> target_edges;
  for (Node* node : graph->nodes()) {
    if (nodes_to_rewrite->find(node->type_string()) !=
            nodes_to_rewrite->end() &&
        !IsGradientNode(graph, node)) {
      // Find out which types the inputs are and convert them accordingly.
      // 1. Const/Variable OP: quantized as signed tensors with no given
      //    range.
      // 2. Activation OP: set the range according to the activation type.
      //    Currently we handle {Relu, Relu6, Sigmoid, Tanh}.
      // 3. Identity OP: the quantization parameters depend on its input.
      // 4. Pooling OPs: various pooling ops. Also depend on their input.
      // 5. Reshape OP: also depends on the first input to this op.
      // 6. Not-listed-above OP: treated as a model input, up to
      //    kAllowedInputs such ops; beyond that we return an error for now
      //    to avoid unexpected behavior.
      // Note: The list above might not be complete. Please let us know if
      // you see the error so we can handle your case.
      for (const Edge* edge : node->in_edges()) {
        if (edge->src_output() == Graph::kControlSlot) {
          // Skip the control dependency input.
          continue;
        } else {
          bool signed_input = false;
          bool range_given = false;
          float input_min = 0;
          float input_max = 0;
          bool known_op = FindType(graph, edge->src(), &signed_input,
                                   &range_given, &input_min, &input_max);
          if (!known_op) {
            // An unknown op is considered as a model input.
            potential_input++;
            if (potential_input > kAllowedInputs) {
              return errors::Unimplemented(
                  "Found an unknown op: ", edge->src()->name(),
                  " with type: ", edge->src()->type_string(),
                  "; Unknown ops are considered as model input for now and "
                  "only ",
                  kAllowedInputs, " inputs are supported currently.");
            }
          }

          target_edges.emplace_back(edge, num_bits, signed_input, range_given,
                                    input_min, input_max);
        }
      }
    }
  }

  TF_RETURN_IF_ERROR(ProcessTargetEdges(graph, quant_op_type, target_edges));

  return Status::OK();
}

Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef,
                                    int32 num_bits, const string& quant_op_type,
                                    GraphDef* result_graphdef) {
  Graph graph(OpRegistry::Global());
  GraphConstructorOptions opts;
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, input_graphdef, &graph));

  // Call the rewriter on the graph.
  TF_RETURN_IF_ERROR(DoQuantizeTraining(num_bits, quant_op_type, &graph));

  // Convert the result graph back to a GraphDef.
  graph.ToGraphDef(result_graphdef);
  return Status::OK();
}

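// Example usage (a sketch; `input_pb` and `output_pb` are placeholder names
// for serialized GraphDef strings):
//
//   string output_pb;
//   TF_CHECK_OK(DoQuantizeTrainingOnSerializedGraphDef(
//       input_pb, /*num_bits=*/8, "FakeQuantWithMinMaxVars", &output_pb));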
Status DoQuantizeTrainingOnSerializedGraphDef(const string& input_graph_string,
                                              int32 num_bits,
                                              const string& quant_op_type,
                                              string* result_graph_string) {
  // First create the graph from the GraphDef.
  GraphDef input_graphdef;
  if (!ParseProtoUnlimited(&input_graphdef, input_graph_string)) {
    return errors::InvalidArgument(
        "input_graph_string is not a serialized GraphDef protocol buffer");
  }
  GraphDef output_graphdef;
  TF_RETURN_IF_ERROR(DoQuantizeTrainingOnGraphDef(
      input_graphdef, num_bits, quant_op_type, &output_graphdef));

  if (!output_graphdef.SerializeToString(result_graph_string)) {
    return errors::Internal(
        "quantize training transformation resulted in invalid GraphDef");
  }
  return Status::OK();
}

}  // namespace tensorflow