/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <atomic>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/graph/quantize_training.h"

#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace {

// TODO(suharshs): If desired, make these values configurable.
const uint32 kAllowedInputs = 2;
const float kEMADecay = 0.999;

// Node types to rewrite. Insert a quantize_and_dequantize op for their inputs.
const auto* nodes_to_rewrite =
    new std::unordered_set<string, StringPieceHasher>{"MatMul", "Conv2D"};

// Contains the necessary parameters to convert an edge.
struct EdgeToConvert {
  // edge is not owned here.
  const Edge* edge;
  int32 num_bits;
  bool signed_input;
  bool range_given;
  float input_min;
  float input_max;

  EdgeToConvert(const Edge* e, int32 bits, bool sign, bool range, float min,
                float max)
      : edge(e),
        num_bits(bits),
        signed_input(sign),
        range_given(range),
        input_min(min),
        input_max(max) {}
};

// Decide if a node is in the backward pass by checking whether its name starts
// with "gradients".
// TODO(jmchen): Make this check more robust as it is not guaranteed that the
// forward node will not be named with a leading "gradients".
inline bool IsGradientNode(const Graph* graph, const Node* node) {
  static const string tag = "gradients";
  return (node->name().compare(0, tag.size(), tag) == 0);
}

// Find the type of the input to set the parameters for the
// quantize_and_dequantize op.
// Returns true if the root tensor op type is known, false otherwise.
bool FindType(const Graph* graph, const Node* node, bool* signed_input,
              bool* range_given, float* input_min, float* input_max) {
  const string& src_op = node->type_string();
  if (src_op == "Const" || src_op == "Variable" || src_op == "VariableV2") {
    *signed_input = true;
    *range_given = false;
  } else if (src_op == "Relu") {
    // Range is not given for Relu.
    *signed_input = false;
    *range_given = false;
  } else if (src_op == "Relu6") {
    // TODO(suharshs): The theoretical min and max are 0 and 6; if the actual
    // activations fall somewhere within this range, we could quantize even
    // further. The same holds for other bounded activations such as Sigmoid.
    *signed_input = false;
    *range_given = true;
    *input_min = 0;
    *input_max = 6;
  } else if (src_op == "Sigmoid") {
    *signed_input = false;
    *range_given = true;
    *input_min = 0;
    *input_max = 1;
  } else if (src_op == "Tanh") {
    *signed_input = true;
    *range_given = true;
    *input_min = -1;
    *input_max = 1;
  } else if (src_op == "Reshape" || src_op == "ConcatV2") {
    // Reshape has 2 inputs and the first one is the tensor.
    // ConcatV2 has many inputs but they should all have the same activation
    // function (e.g. in Inception). So we just recurse on the first input.
    for (const Edge* edge : node->in_edges()) {
      if (edge->src_output() != Graph::kControlSlot && edge->dst_input() == 0) {
        FindType(graph, edge->src(), signed_input, range_given, input_min,
                 input_max);
      }
    }
  } else if (src_op == "Identity" || src_op == "MaxPool" ||
             src_op == "AvgPool" || src_op == "MaxPool3D" ||
             src_op == "AvgPool3D") {
    // All these ops only have 1 data input.
    for (const Edge* edge : node->in_edges()) {
      if (edge->src_output() != Graph::kControlSlot) {
        FindType(graph, edge->src(), signed_input, range_given, input_min,
                 input_max);
      }
    }
  } else {
    // Unknown type, could be the model input examples.
    // TODO(jmchen): Set the params for the input with the user's hint.
    *signed_input = true;
    *range_given = false;
    return false;
  }

  return true;
}

// Find the Save op and its inputs.
Status FindSaveOp(const Graph* graph, Node** save_op,
                  std::vector<const Edge*>* in_edges, bool* found) {
  *found = false;
  for (Node* node : graph->op_nodes()) {
    if (node->type_string() == "SaveV2") {
      // We found multiple save ops.
      if (*found) {
        return errors::InvalidArgument("Input graph has multiple SaveV2 ops.");
      }
      *save_op = node;
      *found = true;
      TF_RETURN_IF_ERROR(node->input_edges(in_edges));
    }
  }
  return Status::OK();
}

Node* FindRestoreAllOp(const Graph* graph, StringPiece save_prefix) {
  for (Node* node : graph->op_nodes()) {
    // The restore_all op should have the same prefix as the save_op.
    if (node->name() == strings::StrCat(save_prefix, "/restore_all")) {
      return node;
    }
  }
  return nullptr;
}

// Strips the last "/suffix" from a name.
// We use this to construct the name of restore ops in the same way they are
// constructed by the Saver.
StringPiece GetNodeNamePrefix(const Node* node) {
  StringPiece name = node->name();
  return name.substr(0, name.rfind('/'));
}

void FillStringTensor(Tensor* dst, const Tensor& src) {
  auto dst_flat = dst->flat<string>();
  auto src_flat = src.flat<string>();
  for (int i = 0; i < src.NumElements(); i++) {
    dst_flat(i) = src_flat(i);
  }
}

// Adds the added_variables as inputs to the Save op.
// We change the inputs of the SaveV2 op to include the names of the added
// variables, and also add the variables themselves as inputs to the save op.
Status ConnectVariablesToSaveOp(Graph* graph, Node* save_op,
                                const std::vector<const Edge*>& in_edges,
                                const std::vector<Node*>& added_variables) {
  Node* tensor_names_op = in_edges[1]->src();
  Node* shape_and_slices_op = in_edges[2]->src();

  // Get the tensor_names and shape_and_slices tensors from the const op.
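  // SaveV2 takes (prefix, tensor_names, shape_and_slices, tensors...), so
  // in_edges[1] and in_edges[2] are the Const ops holding the string lists
  // that name each saved tensor and give its (empty) slice spec.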
  Tensor tensor_names;
  Tensor shape_and_slices;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(tensor_names_op->attrs(), "value", &tensor_names));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(shape_and_slices_op->attrs(), "value", &shape_and_slices));

  int tn_size = tensor_names.NumElements();
  int var_size = added_variables.size();

  // Create a new save_op that has inputs to all the new variables.
  NodeBuilder save_op_builder =
      NodeBuilder(save_op->name(), save_op->type_string());
  // The first three inputs are prefix, tensor_names, and shapes_and_slices.
  for (int i = 0; i < 3; i++) {
    save_op_builder = save_op_builder.Input(in_edges[i]->src());
  }
  std::vector<NodeBuilder::NodeOut> var_nodeouts;
  var_nodeouts.reserve(tn_size + var_size);
  // The rest of the inputs are used to construct the tensor list arg.
  for (int i = 3; i < in_edges.size(); i++) {
    var_nodeouts.emplace_back(in_edges[i]->src());
  }

  // Add the new values to the tensors and the op input.
  Tensor new_tensor_names(DT_STRING, TensorShape({tn_size + var_size}));
  Tensor new_shape_and_slices(DT_STRING, TensorShape({tn_size + var_size}));
  FillStringTensor(&new_tensor_names, tensor_names);
  FillStringTensor(&new_shape_and_slices, shape_and_slices);
  for (int i = 0; i < var_size; i++) {
    Node* var = added_variables[i];
    new_tensor_names.flat<string>()(tn_size + i) = var->name();
    new_shape_and_slices.flat<string>()(tn_size + i) = "";
    var_nodeouts.emplace_back(var);
  }
  save_op_builder = save_op_builder.Input(var_nodeouts);

  // Update the attrs.
  tensor_names_op->AddAttr("value", new_tensor_names);
  shape_and_slices_op->AddAttr("value", new_shape_and_slices);

  // Remove the old save_op and add the new one.
  Node* new_save_op;
  TF_RETURN_IF_ERROR(save_op_builder.Finalize(graph, &new_save_op));
  // Add outputs to the new_save_op; all outputs are control edges.
  for (const Edge* edge : save_op->out_edges()) {
    graph->AddControlEdge(new_save_op, edge->dst());
  }
  graph->RemoveNode(save_op);

  return Status::OK();
}

// Add a restore subgraph for each variable and connect to the restore_all op.
// For each variable we add the following subgraph:
//           Assign----restore_all
//          |      |
//   RestoreV2    Variable
Status AddRestoreVariableSubgraphs(Graph* graph, Node* save_op,
                                   const std::vector<const Edge*>& in_edges,
                                   const std::vector<Node*>& variables) {
  Node* prefix_op = in_edges[0]->src();
  StringPiece name_prefix = GetNodeNamePrefix(save_op);
  Node* restore_all = FindRestoreAllOp(graph, name_prefix);
  if (restore_all == nullptr) {
    return errors::InvalidArgument("graph has SaveOp, but no restore_all NoOp");
  }
  const string restore_op_name = strings::StrCat(name_prefix, "/RestoreV2");
  const string assign_op_name = strings::StrCat(name_prefix, "/Assign");
  for (Node* var : variables) {
    // Add an extra prefix after calling graph->NewName because the "unique"
    // name may conflict with names generated for Send nodes.
    // TODO(b/77547936): fix this more generally and get rid of the extra
    // prefix here.
    string new_restore_op_name =
        strings::StrCat(graph->NewName(restore_op_name), "_qt");
    string new_assign_op_name =
        strings::StrCat(graph->NewName(assign_op_name), "_qt");
    string tensor_names_op_name =
        strings::StrCat(new_restore_op_name, "/tensor_names");
    string shape_and_slices_op_name =
        strings::StrCat(new_restore_op_name, "/shape_and_slices");

    // Construct the tensor_names input with the variable name.
    Node* tensor_names;
    Tensor tensor_names_val(DT_STRING, TensorShape({1}));
    tensor_names_val.flat<string>()(0) = var->name();
    TF_RETURN_IF_ERROR(NodeBuilder(tensor_names_op_name, "Const")
                           .Attr("dtype", DT_STRING)
                           .Attr("value", tensor_names_val)
                           .Finalize(graph, &tensor_names));

    // Construct the shape_and_slices input with an empty string.
    Node* shape_and_slices;
    Tensor shape_and_slices_val(DT_STRING, TensorShape({1}));
    shape_and_slices_val.flat<string>()(0) = "";
    TF_RETURN_IF_ERROR(NodeBuilder(shape_and_slices_op_name, "Const")
                           .Attr("dtype", DT_STRING)
                           .Attr("value", shape_and_slices_val)
                           .Finalize(graph, &shape_and_slices));

    // Build the new Restore op for this variable.
    Node* restore_op;
    TF_RETURN_IF_ERROR(NodeBuilder(new_restore_op_name, "RestoreV2")
                           .Input(prefix_op)
                           .Input(tensor_names)
                           .Input(shape_and_slices)
                           .Attr("dtypes", {DT_FLOAT})
                           .Finalize(graph, &restore_op));

    // Create the Assign op, attaching the variable and Restore op to it.
    Node* assign_op;
    TF_RETURN_IF_ERROR(NodeBuilder(new_assign_op_name, "Assign")
                           .Input(var)
                           .Input(restore_op)
                           .Finalize(graph, &assign_op));

    // Add a control edge from the assign op to the restore_all op.
    graph->AddControlEdge(assign_op, restore_all);
  }
  return Status::OK();
}

// Adds new variables to save and restore ops matching the Save and Restore
// graphs created in tensorflow/python/training/saver.py.
Status AddSaveAndRestore(Graph* graph, const std::vector<Node*>& variables) {
  Node* save_op = nullptr;
  std::vector<const Edge*> in_edges;
  bool found = false;
  TF_RETURN_IF_ERROR(FindSaveOp(graph, &save_op, &in_edges, &found));
  if (found) {
    TF_RETURN_IF_ERROR(
        AddRestoreVariableSubgraphs(graph, save_op, in_edges, variables));
    TF_RETURN_IF_ERROR(
        ConnectVariablesToSaveOp(graph, save_op, in_edges, variables));
  }
  return Status::OK();
}

// Sets *output to the Node that computes reduction axes corresponding to all
// dimensions of input.
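// The subgraph built below is equivalent to range(0, rank(input), 1), i.e. it
// enumerates every axis of input (for a 4-D tensor this yields [0, 1, 2, 3]),
// so the Min/Max ops that consume it reduce over the entire tensor.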
Status MakeReductionAxes(Graph* graph, string name_prefix, Node* input,
                         Node** output) {
  name_prefix = strings::StrCat(name_prefix, "/ReductionAxes");
  Node* start;
  Tensor zero_tensor(DT_INT32, TensorShape());
  zero_tensor.flat<int32>()(0) = 0;
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/RangeStart"), "Const")
          .Attr("dtype", DT_INT32)
          .Attr("value", zero_tensor)
          .Finalize(graph, &start));
  Node* delta;
  Tensor one_tensor(DT_INT32, TensorShape());
  one_tensor.flat<int32>()(0) = 1;
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/RangeDelta"), "Const")
          .Attr("dtype", DT_INT32)
          .Attr("value", one_tensor)
          .Finalize(graph, &delta));
  Node* rank;
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/InputRank"), "Rank")
          .Input(input)
          .Finalize(graph, &rank));
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/ReductionAxes"), "Range")
          .Input(start)
          .Input(rank)
          .Input(delta)
          .Finalize(graph, output));
  return Status::OK();
}

// Computes the exponential moving average of input, updated in update_variable.
Status MakeExponentialMovingAverage(Graph* graph, string name_prefix,
                                    const NodeBuilder::NodeOut& input,
                                    Node* decay, Node* update_variable,
                                    Node** assign_value) {
  // variable_t+1 = variable_t - [(variable_t - value) * (1 - decay)]
  name_prefix = strings::StrCat(name_prefix, "/EMA");
  Node* one;
  Tensor one_tensor(DT_FLOAT, TensorShape());
  one_tensor.flat<float>()(0) = 1.0;
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/OneConst"), "Const")
          .Attr("dtype", DT_FLOAT)
          .Attr("value", one_tensor)
          .Finalize(graph, &one));
  Node* decay_complement;
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/DecayComplement"), "Sub")
          .Input(one)
          .Input(decay)
          .Finalize(graph, &decay_complement));

  Node* value_diff;
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/ValueDiff"), "Sub")
          .Input(update_variable)
          .Input(input)
          .Finalize(graph, &value_diff));
  Node* update_value;
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/UpdateValue"), "Mul")
          .Input(value_diff)
          .Input(decay_complement)
          .Finalize(graph, &update_value));

  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/EMAValue"), "Sub")
          .Input(update_variable)
          .Input(update_value)
          .Finalize(graph, assign_value));
  return Status::OK();
}

// Creates an automatically initialized exponential moving average variable.
// This uses a switch op to assign a value to the variable on the first run,
// and updates it with the moving average on all other runs:
//                   init_val
//                      |
//      var--is_init--switch
//       |      true /      \ false
//       |          |        |
//       |         EMA    init_val
//       |           \      /
//       +----------- assign
Status MakeInitializedEMAVariable(Graph* graph, const string& name, Node* decay,
                                  Node* init_val,
                                  std::vector<Node*>* added_variables,
                                  Node** var) {
  // TODO(suharshs): Update this to use ResourceVariables when they are ready.
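  // Switch forwards init_val on output port 1 when is_initialized is true and
  // on port 0 otherwise; Merge then emits whichever branch actually produced a
  // value. As a result the first Assign seeds the variable with init_val and
  // every later Assign applies the moving-average update.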
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name, "/Variable"), "VariableV2")
          .Attr("shape", TensorShape())
          .Attr("dtype", DT_FLOAT)
          .Finalize(graph, var));
  added_variables->push_back(*var);

  Node* is_initialized;
  TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/IsInitialized"),
                                 "IsVariableInitialized")
                         .Input(*var)
                         .Finalize(graph, &is_initialized));
  Node* switch_node;
  TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/Switch"), "Switch")
                         .Input(init_val)
                         .Input(is_initialized)
                         .Finalize(graph, &switch_node));
  NodeBuilder::NodeOut output_false = NodeBuilder::NodeOut(switch_node, 0);
  NodeBuilder::NodeOut output_true = NodeBuilder::NodeOut(switch_node, 1);

  Node* ema_value;
  TF_RETURN_IF_ERROR(MakeExponentialMovingAverage(graph, name, output_true,
                                                  decay, *var, &ema_value));

  Node* assign_value;
  TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/Merge"), "Merge")
                         .Input({output_false, ema_value})
                         .Finalize(graph, &assign_value));

  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name, "/AssignValue"), "Assign")
          .Input(*var)
          .Input(assign_value)
          .Finalize(graph, var));
  return Status::OK();
}

// Computes the min and max EMA of input and stores them in min_var and max_var.
Status MakeEMAMinMaxVars(Graph* graph, const string& name_prefix, Node* input,
                         std::vector<Node*>* added_variables, Node** min_var,
                         Node** max_var) {
  // TODO(suharshs): The decay will be constant, so we could make only one for
  // all quantize_and_dequantize ops to share; this would have to live outside
  // this function.
  Tensor decay_tensor(DT_FLOAT, TensorShape());
  decay_tensor.flat<float>()(0) = kEMADecay;
  Node* decay;
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat(name_prefix, "/Decay"), "Const")
          .Attr("dtype", DT_FLOAT)
          .Attr("value", decay_tensor)
          .Finalize(graph, &decay));

  Node* reduction_axes;
  TF_RETURN_IF_ERROR(
      MakeReductionAxes(graph, name_prefix, input, &reduction_axes));
  Node* min;
  string min_name = strings::StrCat(name_prefix, "/Min");
  TF_RETURN_IF_ERROR(NodeBuilder(min_name, "Min")
                         .Input(input)
                         .Input(reduction_axes)
                         .Finalize(graph, &min));
  Node* max;
  string max_name = strings::StrCat(name_prefix, "/Max");
  TF_RETURN_IF_ERROR(NodeBuilder(max_name, "Max")
                         .Input(input)
                         .Input(reduction_axes)
                         .Finalize(graph, &max));
  TF_RETURN_IF_ERROR(MakeInitializedEMAVariable(graph, min_name, decay, min,
                                                added_variables, min_var));
  TF_RETURN_IF_ERROR(MakeInitializedEMAVariable(graph, max_name, decay, max,
                                                added_variables, max_var));
  return Status::OK();
}

// Makes an input min and max constant if the range is given. Otherwise, makes
// min and max variables that are updated by an EMA.
Status MakeInputMinMax(Graph* graph, const string& name_prefix,
                       const EdgeToConvert& edge,
                       std::vector<Node*>* added_variables, Node** input_min,
                       Node** input_max) {
  if (edge.range_given) {
    // Make constant nodes for the input_min and input_max if the range is
    // provided.
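    // In this branch the range is baked into Const nodes; the EMA branch below
    // instead creates new variables, which later have to be wired into the
    // Save/Restore ops via added_variables.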
    Tensor input_min_tensor(DT_FLOAT, TensorShape());
    input_min_tensor.flat<float>()(0) = edge.input_min;
    TF_RETURN_IF_ERROR(
        NodeBuilder(strings::StrCat(name_prefix, "/InputMin"), "Const")
            .Attr("dtype", DT_FLOAT)
            .Attr("value", input_min_tensor)
            .Finalize(graph, input_min));
    Tensor input_max_tensor(DT_FLOAT, TensorShape());
    input_max_tensor.flat<float>()(0) = edge.input_max;
    TF_RETURN_IF_ERROR(
        NodeBuilder(strings::StrCat(name_prefix, "/InputMax"), "Const")
            .Attr("dtype", DT_FLOAT)
            .Attr("value", input_max_tensor)
            .Finalize(graph, input_max));
  } else {
    // If the range is not given, estimate the range with EMA variables.
    TF_RETURN_IF_ERROR(MakeEMAMinMaxVars(graph, name_prefix, edge.edge->src(),
                                         added_variables, input_min,
                                         input_max));
  }

  return Status::OK();
}

// Adds a QuantizeAndDequantizeV2 or FakeQuantWithMinMaxVars op
// (and the required input nodes) based on edge.
// The result is stored in convert_node.
Status MakeQuantizeOp(Graph* graph, const string& name_prefix,
                      const string& quant_op_type, const EdgeToConvert& edge,
                      std::vector<Node*>* added_variables,
                      Node** convert_node) {
  Node* input_min;
  Node* input_max;
  TF_RETURN_IF_ERROR(MakeInputMinMax(graph, name_prefix, edge, added_variables,
                                     &input_min, &input_max));
  string quant_name = strings::StrCat(name_prefix, "/", quant_op_type);
  if (quant_op_type == "QuantizeAndDequantizeV2") {
    TF_RETURN_IF_ERROR(NodeBuilder(quant_name, quant_op_type)
                           .Input(edge.edge->src())
                           .Input(input_min)
                           .Input(input_max)
                           .Attr("signed_input", edge.signed_input)
                           .Attr("num_bits", edge.num_bits)
                           .Attr("range_given", true)
                           .Finalize(graph, convert_node));
  } else if (quant_op_type == "FakeQuantWithMinMaxVars") {
    TF_RETURN_IF_ERROR(NodeBuilder(quant_name, quant_op_type)
                           .Input(edge.edge->src())
                           .Input(input_min)
                           .Input(input_max)
                           .Attr("num_bits", edge.num_bits)
                           .Finalize(graph, convert_node));
  } else {
    return errors::InvalidArgument("Unknown quant op type: ", quant_op_type);
  }
  return Status::OK();
}

// Inserts a conversion op, connects it to the graph, and removes the old edge.
Status ProcessTargetEdges(Graph* graph, const string& quant_op_type,
                          const std::vector<EdgeToConvert>& target_edges) {
  // Remember previously converted ops to avoid duplicate conversion on the
  // same input.
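  // name_index is keyed by the source node's name, so every edge that reads
  // the same tensor reuses a single quantize op rather than creating one per
  // consumer.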
  std::unordered_map<string, Node*, StringPieceHasher> name_index;
  std::vector<Node*> added_variables;
  for (const EdgeToConvert& edge : target_edges) {
    Node* convert_node;
    string name_prefix = edge.edge->src()->name();

    auto iter = name_index.find(name_prefix);
    if (iter == name_index.end()) {
      TF_RETURN_IF_ERROR(MakeQuantizeOp(graph, name_prefix, quant_op_type, edge,
                                        &added_variables, &convert_node));
      name_index[name_prefix] = convert_node;
    } else {
      convert_node = iter->second;
    }

    graph->AddEdge(convert_node, 0, edge.edge->dst(), edge.edge->dst_input());
    graph->RemoveEdge(edge.edge);
  }

  TF_RETURN_IF_ERROR(AddSaveAndRestore(graph, added_variables));

  return Status::OK();
}

}  // namespace

Status DoQuantizeTraining(int32 num_bits, const string& quant_op_type,
                          Graph* graph) {
  if (graph == nullptr) {
    return errors::InvalidArgument("Cannot accept empty graph pointer.");
  }

  if (num_bits < 1 || num_bits > 63) {
    return errors::OutOfRange("num_bits should be in range [1, 63] but is: ",
                              num_bits);
  }
  int potential_input = 0;
  std::vector<EdgeToConvert> target_edges;
  for (Node* node : graph->nodes()) {
    if (nodes_to_rewrite->find(node->type_string()) !=
            nodes_to_rewrite->end() &&
        !IsGradientNode(graph, node)) {
      // Find out which types are the inputs and convert them accordingly.
      // 1. Const/Variable OP: These are quantized as signed tensors with no
      //    given range.
      // 2. Activation OP: Set the range accordingly for different types of
      //    activations. Currently we handle {Relu, Relu6, Sigmoid, Tanh}.
      // 3. Identity OP: The quantization parameters depend on its input.
      // 4. Pooling OPs: various pooling ops. Also depend on their input.
      // 5. Reshape OP: Also depends on the first input to this op.
      // 6. Not-Listed-Above OP: Such ops are treated as model inputs. If more
      //    than kAllowedInputs of them are found, return an error for now to
      //    avoid unexpected behavior.
      // Note: The list above might not be complete. Please let us know if you
      // see the error so we can handle your case.
      for (const Edge* edge : node->in_edges()) {
        if (edge->src_output() == Graph::kControlSlot) {
          // Skip the control dependency input.
          continue;
        } else {
          bool signed_input = false;
          bool range_given = false;
          float input_min = 0;
          float input_max = 0;
          bool known_op = FindType(graph, edge->src(), &signed_input,
                                   &range_given, &input_min, &input_max);
          if (!known_op) {
            // An unknown op is considered an input.
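            // This is typically a graph input such as a Placeholder feeding
            // the first MatMul/Conv2D; since no range can be inferred for it,
            // only a small number of such ops is tolerated.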
            potential_input++;
            if (potential_input > kAllowedInputs) {
              return errors::Unimplemented(
                  "Found an unknown op: ", edge->src()->name(),
                  " with type: ", edge->src()->type_string(),
                  "; Unknown ops are considered as model input for now and "
                  "only ",
                  kAllowedInputs, " inputs are supported currently.");
            }
          }

          target_edges.emplace_back(EdgeToConvert(
              edge, num_bits, signed_input, range_given, input_min, input_max));
        }
      }
    }
  }

  TF_RETURN_IF_ERROR(ProcessTargetEdges(graph, quant_op_type, target_edges));

  return Status::OK();
}

Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef,
                                    int32 num_bits, const string& quant_op_type,
                                    GraphDef* result_graphdef) {
  Graph graph(OpRegistry::Global());
  GraphConstructorOptions opts;
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, input_graphdef, &graph));

  // Call the rewriter on the graph.
  TF_RETURN_IF_ERROR(DoQuantizeTraining(num_bits, quant_op_type, &graph));

  // Convert the result graph back to a GraphDef.
  graph.ToGraphDef(result_graphdef);
  return Status::OK();
}

Status DoQuantizeTrainingOnSerializedGraphDef(const string& input_graph_string,
                                              int32 num_bits,
                                              const string& quant_op_type,
                                              string* result_graph_string) {
  // First create the graph from the GraphDef.
  GraphDef input_graphdef;
  if (!ParseProtoUnlimited(&input_graphdef, input_graph_string)) {
    return errors::InvalidArgument(
        "input_graph_string is not a serialized GraphDef protocol buffer");
  }
  GraphDef output_graphdef;
  TF_RETURN_IF_ERROR(DoQuantizeTrainingOnGraphDef(
      input_graphdef, num_bits, quant_op_type, &output_graphdef));

  if (!output_graphdef.SerializeToString(result_graph_string)) {
    return errors::Internal(
        "quantize training transformation resulted in invalid GraphDef");
  }
  return Status::OK();
}

}  // namespace tensorflow