/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <atomic>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/graph/quantize_training.h"

#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace {

// TODO(suharshs): If desired, make these values configurable.
const uint32 kAllowedInputs = 2;
const float kEMADecay = 0.999;

// Node types to rewrite. Insert quantize_and_dequantize op for their inputs.
const auto* nodes_to_rewrite =
    new std::unordered_set<string, StringPieceHasher>{"MatMul", "Conv2D"};

// Contains necessary parameters to convert an edge.
struct EdgeToConvert {
  // edge is not owned here.
  const Edge* edge;
  int32 num_bits;
  bool signed_input;
  bool range_given;
  float input_min;
  float input_max;

  EdgeToConvert(const Edge* e, int32 bits, bool sign, bool range, float min,
                float max)
      : edge(e),
        num_bits(bits),
        signed_input(sign),
        range_given(range),
        input_min(min),
        input_max(max) {}
};

// Decide if a node is in the backward pass by checking whether its name starts
// with "gradients".
// TODO(jmchen): Make this check more robust as it is not guaranteed that the
// forward node will not be named with a leading "gradients".
inline bool IsGradientNode(const Graph* graph, const Node* node) {
  static const string tag = "gradients";
  return (node->name().compare(0, tag.size(), tag) == 0);
}

// Find the type of the input to set the parameters for the
// quantize_and_dequantize op.
// Returns true if the root tensor op type is known, false otherwise.
bool FindType(const Graph* graph, const Node* node, bool* signed_input,
              bool* range_given, float* input_min, float* input_max) {
  const string& src_op = node->type_string();
  if (src_op == "Const" || src_op == "Variable" || src_op == "VariableV2") {
    *signed_input = true;
    *range_given = false;
  } else if (src_op == "Relu") {
    // Range is not given for Relu.
    *signed_input = false;
    *range_given = false;
  } else if (src_op == "Relu6") {
    // TODO(suharshs): Also, the theoretical min and max are 0 and 6; if the
    // actual activations fall within a narrower range, we could quantize
    // this even further.
    // This is true for other activations like Sigmoid6 too.
    *signed_input = false;
    *range_given = true;
    *input_min = 0;
    *input_max = 6;
  } else if (src_op == "Sigmoid") {
    *signed_input = false;
    *range_given = true;
    *input_min = 0;
    *input_max = 1;
  } else if (src_op == "Tanh") {
    *signed_input = true;
    *range_given = true;
    *input_min = -1;
    *input_max = 1;
  } else if (src_op == "Reshape" || src_op == "ConcatV2") {
    // Reshape has 2 inputs and the first one is the tensor.
    // ConcatV2 has many inputs but they should all have the same activation
    // function (e.g. Inception). So we just recurse on the first input.
    for (const Edge* edge : node->in_edges()) {
      if (edge->src_output() != Graph::kControlSlot && edge->dst_input() == 0) {
        FindType(graph, edge->src(), signed_input, range_given, input_min,
                 input_max);
      }
    }
  } else if (src_op == "Identity" || src_op == "MaxPool" ||
             src_op == "AvgPool" || src_op == "MaxPool3D" ||
             src_op == "AvgPool3D") {
    // All these Ops only have 1 data input.
    for (const Edge* edge : node->in_edges()) {
      if (edge->src_output() != Graph::kControlSlot) {
        FindType(graph, edge->src(), signed_input, range_given, input_min,
                 input_max);
      }
    }
  } else {
    // Unknown type, could be the model input examples.
    // TODO(jmchen): Set the params for input with user's hint.
    *signed_input = true;
    *range_given = false;
    return false;
  }

  return true;
}

// Find the Save op and its inputs.
Status FindSaveOp(const Graph* graph, Node** save_op,
                  std::vector<const Edge*>* in_edges, bool* found) {
  *found = false;
  for (Node* node : graph->op_nodes()) {
    if (node->type_string() == "SaveV2") {
      // We found multiple save ops.
      if (*found) {
        return errors::InvalidArgument("Input graph has multiple SaveV2 ops.");
      }
      *save_op = node;
      *found = true;
      TF_RETURN_IF_ERROR(node->input_edges(in_edges));
    }
  }
  return Status::OK();
}

Node* FindRestoreAllOp(const Graph* graph, StringPiece save_prefix) {
  for (Node* node : graph->op_nodes()) {
    // The restore_all op should have the same prefix as the save_op.
    if (node->name() == strings::StrCat(save_prefix, "/restore_all")) {
      return node;
    }
  }
  return nullptr;
}

// Strips the last "/suffix" from a name.
// We use this to construct the name of restore ops in the same way they are
// constructed by the Saver.
StringPiece GetNodeNamePrefix(const Node* node) {
  StringPiece name = node->name();
  return name.substr(0, name.rfind('/'));
}

void FillStringTensor(Tensor* dst, const Tensor& src) {
  auto dst_flat = dst->flat<string>();
  auto src_flat = src.flat<string>();
  for (int i = 0; i < src.NumElements(); i++) {
    dst_flat(i) = src_flat(i);
  }
}

// Add the added_variables as inputs to the Save op.
// We change the inputs of the SaveV2 op to include the names of the added
// variables, and we also add the variables themselves as inputs to the op.
Status ConnectVariablesToSaveOp(Graph* graph, Node* save_op,
                                const std::vector<const Edge*>& in_edges,
                                const std::vector<Node*>& added_variables) {
  Node* tensor_names_op = in_edges[1]->src();
  Node* shape_and_slices_op = in_edges[2]->src();

  // Get the tensor_names and shape_and_slices tensors from the const op.
  Tensor tensor_names;
  Tensor shape_and_slices;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(tensor_names_op->attrs(), "value", &tensor_names));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(shape_and_slices_op->attrs(), "value", &shape_and_slices));

  int tn_size = tensor_names.NumElements();
  int var_size = added_variables.size();

  // Create a new save_op that has inputs to all the new variables.
  NodeBuilder save_op_builder =
      NodeBuilder(save_op->name(), save_op->type_string());
  // The first three inputs are prefix, tensor_names, and shapes_and_slices.
  for (int i = 0; i < 3; i++) {
    save_op_builder = save_op_builder.Input(in_edges[i]->src());
  }
  std::vector<NodeBuilder::NodeOut> var_nodeouts;
  var_nodeouts.reserve(tn_size + var_size);
  // The rest of the inputs need to be used to construct the tensor list arg.
  for (int i = 3; i < in_edges.size(); i++) {
    var_nodeouts.emplace_back(in_edges[i]->src());
  }

  // Add the new values to the tensors and the op input.
  Tensor new_tensor_names(DT_STRING, TensorShape({tn_size + var_size}));
  Tensor new_shape_and_slices(DT_STRING, TensorShape({tn_size + var_size}));
  FillStringTensor(&new_tensor_names, tensor_names);
  FillStringTensor(&new_shape_and_slices, shape_and_slices);
  for (int i = 0; i < var_size; i++) {
    Node* var = added_variables[i];
    new_tensor_names.flat<string>()(tn_size + i) = var->name();
    new_shape_and_slices.flat<string>()(tn_size + i) = "";
    var_nodeouts.emplace_back(var);
  }
  save_op_builder = save_op_builder.Input(var_nodeouts);

  // Update the attrs.
  tensor_names_op->AddAttr("value", new_tensor_names);
  shape_and_slices_op->AddAttr("value", new_shape_and_slices);

  // Remove the old save_op and add the new one.
  Node* new_save_op;
  TF_RETURN_IF_ERROR(save_op_builder.Finalize(graph, &new_save_op));
  // Add outputs to the new_save_op, all outputs are control edges.
  for (const Edge* edge : save_op->out_edges()) {
    graph->AddControlEdge(new_save_op, edge->dst());
  }
  graph->RemoveNode(save_op);

  return Status::OK();
}

// Add a restore subgraph for each variable and connect to the restore_all op.
// For each variable we add the following subgraph:
//        Assign----restore_all
//        |      |
//   RestoreV2  Variable
Status AddRestoreVariableSubgraphs(Graph* graph, Node* save_op,
                                   const std::vector<const Edge*>& in_edges,
                                   const std::vector<Node*>& variables) {
  Node* prefix_op = in_edges[0]->src();
  StringPiece name_prefix = GetNodeNamePrefix(save_op);
  Node* restore_all = FindRestoreAllOp(graph, name_prefix);
  if (restore_all == nullptr) {
    return errors::InvalidArgument("graph has SaveOp, but no restore_all NoOp");
  }
  const string restore_op_name = strings::StrCat(name_prefix, "/RestoreV2");
  const string assign_op_name = strings::StrCat(name_prefix, "/Assign");
  for (Node* var : variables) {
    string new_restore_op_name = graph->NewName(restore_op_name);
    string new_assign_op_name = graph->NewName(assign_op_name);
    string tensor_names_op_name =
        strings::StrCat(new_restore_op_name, "/tensor_names");
    string shape_and_slices_op_name =
        strings::StrCat(new_restore_op_name, "/shape_and_slices");

    // Construct the tensor_names input with the variable name.
    Node* tensor_names;
    Tensor tensor_names_val(DT_STRING, TensorShape({1}));
    tensor_names_val.flat<string>()(0) = var->name();
    TF_RETURN_IF_ERROR(NodeBuilder(tensor_names_op_name, "Const")
                           .Attr("dtype", DT_STRING)
                           .Attr("value", tensor_names_val)
                           .Finalize(graph, &tensor_names));

    // Construct the shape_and_slices input with an empty string.
    Node* shape_and_slices;
    Tensor shape_and_slices_val(DT_STRING, TensorShape({1}));
    shape_and_slices_val.flat<string>()(0) = "";
    TF_RETURN_IF_ERROR(NodeBuilder(shape_and_slices_op_name, "Const")
                           .Attr("dtype", DT_STRING)
                           .Attr("value", shape_and_slices_val)
                           .Finalize(graph, &shape_and_slices));

    // Build the new Restore op for this variable.
    Node* restore_op;
    TF_RETURN_IF_ERROR(NodeBuilder(new_restore_op_name, "RestoreV2")
                           .Input(prefix_op)
                           .Input(tensor_names)
                           .Input(shape_and_slices)
                           .Attr("dtypes", {DT_FLOAT})
                           .Finalize(graph, &restore_op));

    // Create the Assign op, attaching the variable and Restore op to it.
    Node* assign_op;
    TF_RETURN_IF_ERROR(NodeBuilder(new_assign_op_name, "Assign")
                           .Input(var)
                           .Input(restore_op)
                           .Finalize(graph, &assign_op));

    // Add a control edge from the assign op to the restore_all op.
    graph->AddControlEdge(assign_op, restore_all);
  }
  return Status::OK();
}

// Adds new variables to save and restore ops matching the Save and Restore
// graphs created in tensorflow/python/training/saver.py.
Status AddSaveAndRestore(Graph* graph, const std::vector<Node*>& variables) {
  Node* save_op = nullptr;
  std::vector<const Edge*> in_edges;
  bool found = false;
  TF_RETURN_IF_ERROR(FindSaveOp(graph, &save_op, &in_edges, &found));
  if (found) {
    TF_RETURN_IF_ERROR(
        AddRestoreVariableSubgraphs(graph, save_op, in_edges, variables));
    TF_RETURN_IF_ERROR(
        ConnectVariablesToSaveOp(graph, save_op, in_edges, variables));
  }
  return Status::OK();
}

// Sets *output to a Node that computes the reduction axes corresponding to
// all dimensions of input.
Status MakeReductionAxes(Graph* graph, string name_prefix, Node* input,
                         Node** output) {
  name_prefix = strings::StrCat(name_prefix, "/ReductionAxes");
  Node* start;
  Tensor zero_tensor(DT_INT32, TensorShape());
  zero_tensor.flat<int32>()(0) = 0;
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/RangeStart"), "Const")
          .Attr("dtype", DT_INT32)
          .Attr("value", zero_tensor)
          .Finalize(graph, &start));
  Node* delta;
  Tensor one_tensor(DT_INT32, TensorShape());
  one_tensor.flat<int32>()(0) = 1;
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/RangeDelta"), "Const")
          .Attr("dtype", DT_INT32)
          .Attr("value", one_tensor)
          .Finalize(graph, &delta));
  Node* rank;
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/InputRank"), "Rank")
          .Input(input)
          .Finalize(graph, &rank));
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/ReductionAxes"), "Range")
          .Input(start)
          .Input(rank)
          .Input(delta)
          .Finalize(graph, output));
  return Status::OK();
}

// Computes the exponential moving average of input, updated in
// update_variable.
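// The value written back to update_variable is
//   update_variable - (update_variable - input) * (1 - decay),
// which is equivalent to decay * update_variable + (1 - decay) * input.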
Status MakeExponentialMovingAverage(Graph* graph, string name_prefix,
                                    const NodeBuilder::NodeOut& input,
                                    Node* decay, Node* update_variable,
                                    Node** assign_value) {
  // variable_t+1 = variable_t - [(variable_t - value) * (1 - decay)]
  name_prefix = strings::StrCat(name_prefix, "/EMA");
  Node* one;
  Tensor one_tensor(DT_FLOAT, TensorShape());
  one_tensor.flat<float>()(0) = 1.0;
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/OneConst"), "Const")
          .Attr("dtype", DT_FLOAT)
          .Attr("value", one_tensor)
          .Finalize(graph, &one));
  Node* decay_complement;
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/DecayComplement"), "Sub")
          .Input(one)
          .Input(decay)
          .Finalize(graph, &decay_complement));

  Node* value_diff;
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/ValueDiff"), "Sub")
          .Input(update_variable)
          .Input(input)
          .Finalize(graph, &value_diff));
  Node* update_value;
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/UpdateValue"), "Mul")
          .Input(value_diff)
          .Input(decay_complement)
          .Finalize(graph, &update_value));

  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/EMAValue"), "Sub")
          .Input(update_variable)
          .Input(update_value)
          .Finalize(graph, assign_value));
  return Status::OK();
}

// Creates an automatically initialized exponential moving average variable.
// This uses a switch op to assign a value to the variable on the first run,
// and updates it with the moving average on all other runs:
//                init_val
//                   |
//   var--is_init--switch
//    |       true /      \ false
//    |           |        |
//    |          EMA    init_val
//    |            \      /
//    +------------ assign
Status MakeInitializedEMAVariable(Graph* graph, const string& name, Node* decay,
                                  Node* init_val,
                                  std::vector<Node*>* added_variables,
                                  Node** var) {
  // TODO(suharshs): Update this to use ResourceVariables when they are ready.
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name, "/Variable"), "VariableV2")
          .Attr("shape", TensorShape())
          .Attr("dtype", DT_FLOAT)
          .Finalize(graph, var));
  added_variables->push_back(*var);

  Node* is_initialized;
  TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/IsInitialized"),
                                 "IsVariableInitialized")
                         .Input(*var)
                         .Finalize(graph, &is_initialized));
  Node* switch_node;
  TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/Switch"), "Switch")
                         .Input(init_val)
                         .Input(is_initialized)
                         .Finalize(graph, &switch_node));
  NodeBuilder::NodeOut output_false = NodeBuilder::NodeOut(switch_node, 0);
  NodeBuilder::NodeOut output_true = NodeBuilder::NodeOut(switch_node, 1);

  Node* ema_value;
  TF_RETURN_IF_ERROR(MakeExponentialMovingAverage(graph, name, output_true,
                                                  decay, *var, &ema_value));

  Node* assign_value;
  TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/Merge"), "Merge")
                         .Input({output_false, ema_value})
                         .Finalize(graph, &assign_value));

  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name, "/AssignValue"), "Assign")
          .Input(*var)
          .Input(assign_value)
          .Finalize(graph, var));
  return Status::OK();
}

// Computes the min and max EMA of input and stores them in min_var and
// max_var.
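// The per-step min and max are obtained by reducing input over all of its
// dimensions; the EMA variables created here are appended to added_variables
// so they can later be connected to the Save/Restore ops.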
Status MakeEMAMinMaxVars(Graph* graph, const string& name_prefix, Node* input,
                         std::vector<Node*>* added_variables, Node** min_var,
                         Node** max_var) {
  // TODO(suharshs): The decay will be constant, so we could make only one for
  // all quantize_and_dequantize ops to share; this would have to live outside
  // this function.
  Tensor decay_tensor(DT_FLOAT, TensorShape());
  decay_tensor.flat<float>()(0) = kEMADecay;
  Node* decay;
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/Decay"), "Const")
          .Attr("dtype", DT_FLOAT)
          .Attr("value", decay_tensor)
          .Finalize(graph, &decay));

  Node* reduction_axes;
  TF_RETURN_IF_ERROR(
      MakeReductionAxes(graph, name_prefix, input, &reduction_axes));
  Node* min;
  string min_name = strings::StrCat(name_prefix, "/Min");
  TF_RETURN_IF_ERROR(NodeBuilder(min_name, "Min")
                         .Input(input)
                         .Input(reduction_axes)
                         .Finalize(graph, &min));
  Node* max;
  string max_name = strings::StrCat(name_prefix, "/Max");
  TF_RETURN_IF_ERROR(NodeBuilder(max_name, "Max")
                         .Input(input)
                         .Input(reduction_axes)
                         .Finalize(graph, &max));
  TF_RETURN_IF_ERROR(MakeInitializedEMAVariable(graph, min_name, decay, min,
                                                added_variables, min_var));
  TF_RETURN_IF_ERROR(MakeInitializedEMAVariable(graph, max_name, decay, max,
                                                added_variables, max_var));
  return Status::OK();
}

// Makes input_min and input_max constants if the range is given. Otherwise,
// makes min and max variables that are updated by an EMA.
Status MakeInputMinMax(Graph* graph, const string& name_prefix,
                       const EdgeToConvert& edge,
                       std::vector<Node*>* added_variables, Node** input_min,
                       Node** input_max) {
  if (edge.range_given) {
    // Make constant nodes for the input_min and input_max if the range is
    // provided.
    Tensor input_min_tensor(DT_FLOAT, TensorShape());
    input_min_tensor.flat<float>()(0) = edge.input_min;
    TF_RETURN_IF_ERROR(
        NodeBuilder(strings::StrCat(name_prefix, "/InputMin"), "Const")
            .Attr("dtype", DT_FLOAT)
            .Attr("value", input_min_tensor)
            .Finalize(graph, input_min));
    Tensor input_max_tensor(DT_FLOAT, TensorShape());
    input_max_tensor.flat<float>()(0) = edge.input_max;
    TF_RETURN_IF_ERROR(
        NodeBuilder(strings::StrCat(name_prefix, "/InputMax"), "Const")
            .Attr("dtype", DT_FLOAT)
            .Attr("value", input_max_tensor)
            .Finalize(graph, input_max));
  } else {
    // If the range is not given, estimate the range with EMA variables.
    TF_RETURN_IF_ERROR(MakeEMAMinMaxVars(graph, name_prefix, edge.edge->src(),
                                         added_variables, input_min,
                                         input_max));
  }

  return Status::OK();
}

// Adds a QuantizeAndDequantizeV2 or FakeQuantWithMinMaxVars op
// (and the required input nodes) based on edge.
// The result is stored in convert_node.
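// Both op types take the source tensor, input_min, and input_max as inputs;
// QuantizeAndDequantizeV2 additionally gets signed_input, num_bits, and
// range_given=true as attrs, while FakeQuantWithMinMaxVars only gets num_bits.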
Status MakeQuantizeOp(Graph* graph, const string& name_prefix,
                      const string& quant_op_type, const EdgeToConvert& edge,
                      std::vector<Node*>* added_variables,
                      Node** convert_node) {
  Node* input_min;
  Node* input_max;
  TF_RETURN_IF_ERROR(MakeInputMinMax(graph, name_prefix, edge, added_variables,
                                     &input_min, &input_max));
  string quant_name = strings::StrCat(name_prefix, "/", quant_op_type);
  if (quant_op_type == "QuantizeAndDequantizeV2") {
    TF_RETURN_IF_ERROR(NodeBuilder(quant_name, quant_op_type)
                           .Input(edge.edge->src())
                           .Input(input_min)
                           .Input(input_max)
                           .Attr("signed_input", edge.signed_input)
                           .Attr("num_bits", edge.num_bits)
                           .Attr("range_given", true)
                           .Finalize(graph, convert_node));
  } else if (quant_op_type == "FakeQuantWithMinMaxVars") {
    TF_RETURN_IF_ERROR(NodeBuilder(quant_name, quant_op_type)
                           .Input(edge.edge->src())
                           .Input(input_min)
                           .Input(input_max)
                           .Attr("num_bits", edge.num_bits)
                           .Finalize(graph, convert_node));
  } else {
    return errors::InvalidArgument("Unknown quant op type: ", quant_op_type);
  }
  return Status::OK();
}

// Insert conversion op, connect it to the graph and remove the old edge.
Status ProcessTargetEdges(Graph* graph, const string& quant_op_type,
                          const std::vector<EdgeToConvert>& target_edges) {
  // Remember previously converted ops to avoid duplicated conversion on the
  // same input.
  std::unordered_map<string, Node*, StringPieceHasher> name_index;
  std::vector<Node*> added_variables;
  for (const EdgeToConvert edge : target_edges) {
    Node* convert_node;
    string name_prefix = edge.edge->src()->name();

    auto iter = name_index.find(name_prefix);
    if (iter == name_index.end()) {
      TF_RETURN_IF_ERROR(MakeQuantizeOp(graph, name_prefix, quant_op_type, edge,
                                        &added_variables, &convert_node));
      name_index[name_prefix] = convert_node;
    } else {
      convert_node = iter->second;
    }

    graph->AddEdge(convert_node, 0, edge.edge->dst(), edge.edge->dst_input());
    graph->RemoveEdge(edge.edge);
  }

  TF_RETURN_IF_ERROR(AddSaveAndRestore(graph, added_variables));

  return Status::OK();
}

}  // namespace

Status DoQuantizeTraining(int32 num_bits, const string& quant_op_type,
                          Graph* graph) {
  if (graph == nullptr) {
    return errors::InvalidArgument("Cannot accept empty graph pointer.");
  }

  if (num_bits < 1 || num_bits > 63) {
    return errors::OutOfRange("num_bits should be in range [1, 63] but is: ",
                              num_bits);
  }
  int potential_input = 0;
  std::vector<EdgeToConvert> target_edges;
  for (Node* node : graph->nodes()) {
    if (nodes_to_rewrite->find(node->type_string()) !=
            nodes_to_rewrite->end() &&
        !IsGradientNode(graph, node)) {
      // Find out which types are the inputs and convert them accordingly.
      // 1. Const/Variable OP: This is quantized as signed tensors with no
      // given range.
      // 2. Activation OP: Set the range accordingly for different types of
      // activations. Currently we handle {Relu, Relu6, Sigmoid, Tanh}.
      // 3. Identity OP: The quantization parameters depend on its input.
      // 4. Pooling OPs: various pooling ops. Also depends on its input.
      // 5. Reshape OP: Also depends on the first input to this op.
      // 6. Not-Listed-Above OP: If there is only 1 such op, consider it as the
      // model input.
      // However, if there are >1 unknown ops, an error is returned for now to
      // avoid unexpected behavior.
      // Note: The list above might not be a complete list. Please let us
      // know if you see the error so we can handle your case.
      for (const Edge* edge : node->in_edges()) {
        if (edge->src_output() == Graph::kControlSlot) {
          // Skip the control dependency input.
          continue;
        } else {
          bool signed_input = false;
          bool range_given = false;
          float input_min = 0;
          float input_max = 0;
          bool known_op = FindType(graph, edge->src(), &signed_input,
                                   &range_given, &input_min, &input_max);
          if (!known_op) {
            // An unknown op is considered a model input.
            potential_input++;
            if (potential_input > kAllowedInputs) {
              return errors::Unimplemented(
                  "Found an unknown op: ", edge->src()->name(),
                  " with type: ", edge->src()->type_string(),
                  "; Unknown ops are considered as model input for now and "
                  "only ",
                  kAllowedInputs, " inputs are supported currently.");
            }
          }

          target_edges.emplace_back(EdgeToConvert(
              edge, num_bits, signed_input, range_given, input_min, input_max));
        }
      }
    }
  }

  TF_RETURN_IF_ERROR(ProcessTargetEdges(graph, quant_op_type, target_edges));

  return Status::OK();
}

Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef,
                                    int32 num_bits, const string& quant_op_type,
                                    GraphDef* result_graphdef) {
  Graph graph(OpRegistry::Global());
  GraphConstructorOptions opts;
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, input_graphdef, &graph));

  // Call the rewriter on the graph.
  TF_RETURN_IF_ERROR(DoQuantizeTraining(num_bits, quant_op_type, &graph));

  // Convert the result graph back to a GraphDef.
  graph.ToGraphDef(result_graphdef);
  return Status::OK();
}

Status DoQuantizeTrainingOnSerializedGraphDef(const string& input_graph_string,
                                              int32 num_bits,
                                              const string& quant_op_type,
                                              string* result_graph_string) {
  // First create the graph from the GraphDef.
  GraphDef input_graphdef;
  if (!ParseProtoUnlimited(&input_graphdef, input_graph_string)) {
    return errors::InvalidArgument(
        "input_graph_string is not a serialized GraphDef protocol buffer");
  }
  GraphDef output_graphdef;
  TF_RETURN_IF_ERROR(DoQuantizeTrainingOnGraphDef(
      input_graphdef, num_bits, quant_op_type, &output_graphdef));

  if (!output_graphdef.SerializeToString(result_graph_string)) {
    return errors::Internal(
        "quantize training transformation resulted in invalid GraphDef");
  }
  return Status::OK();
}

}  // namespace tensorflow