/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

namespace tensorflow {
namespace graph_transforms {

// Holds the information we need to translate from a float version of this op
// into the quantized equivalent.
struct QuantizedOpInfo {
  // The name of the float op.
  string float_name;
  // Which attributes to copy directly over.
  std::vector<string> attrs_to_copy;
  // Extra data type attributes we need to set.
  std::vector<std::pair<string, DataType>> dtypes_to_set;
  // The bit depth of quantized inputs the op can read.
  DataType input_bit_depth;
  // The bit depth of the op's quantized outputs.
  DataType output_bit_depth;
  // Which inputs (e.g. shapes) aren't involved in the quantization process.
  std::set<int32> unquantized_inputs;
  // How the min/max inputs to the quantized op are arranged, either
  // [input0, input1, min0, max0, min1, max1] for contiguous, or
  // [input0, input1, min0, min1, max0, max1] for separate.
  // The separate order is needed because it's the only way to specify unknown
  // numbers of inputs for ops like Concat.
  enum { CONTIGUOUS_MIN_MAX, SEPARATE_MIN_MAX } min_max_order;
};

// Every op that has a quantized equivalent should be listed here, so that the
// conversion process can transform them.
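// As an illustration of how the fields above are consumed: the Conv2D entry
// below copies the "strides" and "padding" attributes onto the generated
// QuantizedConv2D node, sets its Tinput/Tfilter/out_type dtype attributes,
// reads eight-bit inputs, and produces a 32-bit output that a later
// Requantize step narrows back down to eight bits.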
const std::vector<QuantizedOpInfo>& GetQuantizedOpList() {
  static const std::vector<QuantizedOpInfo> op_list = {
      {"Add",
       {},
       {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
       DT_QUINT8,
       DT_QINT32,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"AvgPool",
       {"ksize", "strides", "padding"},
       {{"T", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"BiasAdd",
       {},
       {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"out_type", DT_QINT32}},
       DT_QUINT8,
       DT_QINT32,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Concat",
       {"N"},
       {{"T", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {0},
       QuantizedOpInfo::SEPARATE_MIN_MAX},
      {"Conv2D",
       {"strides", "padding"},
       {{"Tinput", DT_QUINT8}, {"Tfilter", DT_QUINT8}, {"out_type", DT_QINT32}},
       DT_QUINT8,
       DT_QINT32,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"MatMul",
       {"transpose_a", "transpose_b"},
       {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
       DT_QUINT8,
       DT_QINT32,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"MaxPool",
       {"ksize", "strides", "padding"},
       {{"T", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Mul",
       {},
       {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
       DT_QUINT8,
       DT_QINT32,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Relu",
       {},
       {{"Tinput", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"ResizeBilinear",
       {"align_corners"},
       {{"T", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {1},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Relu6",
       {},
       {{"Tinput", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Reshape",
       {},
       {{"T", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {1},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
  };
  return op_list;
}

namespace {
// Replaces invalid characters in input names to get a unique node name.
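// For example (illustrative values): an input of "^conv1" maps to
// "__hat__conv1", and "conv1:2" maps to "conv1__port__2".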
string UniqueNodeNameFromInput(const string& input_name) {
  string prefix;
  string node_name;
  string suffix;
  NodeNamePartsFromInput(input_name, &prefix, &node_name, &suffix);
  string result;
  if (prefix == "^") {
    result += "__hat__";
  }
  result += node_name;
  if (!suffix.empty()) {
    result += "__port__" + suffix.substr(1, suffix.size() - 1);
  }
  return result;
}

// Pulls two float values from the named parameters, with a lot of checking.
Status ExtractRangeFromParams(const TransformFuncContext& context,
                              const string& min_name, const string& max_name,
                              float* min_value, float* max_value,
                              bool* has_range) {
  // See if we've been given quantized inputs with a known range.
  const bool has_min = (context.params.count(min_name) != 0);
  const bool has_max = (context.params.count(max_name) != 0);
  *has_range = (has_min || has_max);
  if (!*has_range) {
    return Status::OK();
  }
  if (!has_min || !has_max) {
    return errors::InvalidArgument("You must pass both ", min_name, " and ",
                                   max_name, " into quantize_nodes");
  }
  TF_RETURN_IF_ERROR(context.GetOneFloatParameter(min_name, 0.0f, min_value));
  TF_RETURN_IF_ERROR(context.GetOneFloatParameter(max_name, 0.0f, max_value));
  return Status::OK();
}

}  // namespace

// Analyzes all the nodes in the graph to figure out which ones are duplicates
// apart from their names. This commonly includes identical Const nodes, but
// can also be simple operations that are repeated on multiple outputs of a
// particular node. The complexity is managed using a hash function that
// avoids the need for any O(n^2) algorithms when identifying duplicates.
Status MergeDuplicateNodes(const GraphDef& input_graph_def,
                           const TransformFuncContext& context,
                           GraphDef* output_graph_def) {
  // Make sure we can look up inputs and outputs quickly.
  std::set<string> input_names(context.input_names.begin(),
                               context.input_names.end());
  std::set<string> output_names(context.output_names.begin(),
                                context.output_names.end());
  GraphDef current_graph_def = input_graph_def;
  // Keep running the merging until no more duplicates are found.
  bool any_duplicates_found;
  do {
    any_duplicates_found = false;
    // First arrange all of the nodes by a hash of their contents.
    std::map<uint64, std::vector<const NodeDef*>> hashed_nodes;
    for (const NodeDef& node : current_graph_def.node()) {
      NodeDef nameless_node = node;
      // The name matters if it's being used as an input or output node,
      // otherwise ignore it when looking for duplicates.
      if (!input_names.count(node.name()) &&
          !output_names.count(node.name())) {
        nameless_node.set_name("");
      }
      const uint64 hash = HashNodeDef(nameless_node);
      hashed_nodes[hash].push_back(&node);
    }
    // If we have multiple nodes with the same hash, then we know they're
    // duplicates and can be removed, unless they're stateful.
    std::map<string, string> inputs_to_rename;
    GraphDef merged_graph_def;
    for (const auto& hashed_node_info : hashed_nodes) {
      const std::vector<const NodeDef*>& hash_node_list =
          hashed_node_info.second;
      for (int i = 0; i < hash_node_list.size(); ++i) {
        const NodeDef* current_node = hash_node_list[i];
        const OpDef* op_def = nullptr;
        TF_RETURN_IF_ERROR(
            OpRegistry::Global()->LookUpOpDef(current_node->op(), &op_def));
        const bool is_duplicate = ((!op_def->is_stateful()) && (i > 0));
        if (is_duplicate) {
          const string original_name = hash_node_list[0]->name();
          inputs_to_rename[current_node->name() + ":*"] = original_name;
          any_duplicates_found = true;
        } else {
          NodeDef* new_node = merged_graph_def.mutable_node()->Add();
          *new_node = *current_node;
        }
      }
    }
    // Update the graph so that any nodes that referred to removed inputs now
    // pull from the remaining duplicate.
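    // Note that the ":*" key form is a wildcard, so RenameNodeInputs
    // redirects every output port of a removed duplicate (e.g. "dup:0",
    // "dup:1") to the surviving node.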
    TF_RETURN_IF_ERROR(RenameNodeInputs(merged_graph_def, inputs_to_rename,
                                        std::unordered_set<string>(),
                                        &current_graph_def));
  } while (any_duplicates_found);

  *output_graph_def = current_graph_def;

  return Status::OK();
}

// Looks for the patterns that indicate there are two eight-bit ops feeding
// into each other, separated by a conversion up to float and back again.
// These occur during the initial conversion of ops to their quantized forms.
// Because we're only looking at an individual op in that phase and don't know
// if its inputs and outputs are eight-bit-capable, we start by converting the
// actual op into quantized form, but add float conversions before and after.
// This pass gets rid of those conversions if it turns out we do have adjacent
// ops capable of eight-bit processing.
Status RemoveRedundantQuantizations(const GraphDef& input_graph_def,
                                    const TransformFuncContext& context,
                                    GraphDef* output_graph_def) {
  std::set<string> graph_outputs;
  for (const string& output_name : context.output_names) {
    graph_outputs.insert(NodeNameFromInput(output_name));
  }
  std::map<string, string> inputs_to_rename;
  GraphDef replaced_graph_def;
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      input_graph_def,  // clang-format off
      {"QuantizeV2",
        {
          {"Dequantize"},
          {"Min"},
          {"Max"},
        }
      },  // clang-format on
      [&inputs_to_rename, &graph_outputs](const NodeMatch& match,
                                          const std::set<string>& input_nodes,
                                          const std::set<string>& output_nodes,
                                          std::vector<NodeDef>* new_nodes) {
        const NodeDef& quantize_node = match.node;
        const NodeDef& dequantize_node = match.inputs[0].node;
        inputs_to_rename[quantize_node.name() + ":0"] =
            dequantize_node.input(0);
        inputs_to_rename[quantize_node.name() + ":1"] =
            dequantize_node.input(1);
        inputs_to_rename[quantize_node.name() + ":2"] =
            dequantize_node.input(2);

        // Are other sub-graphs using the float intermediate result? If so,
        // preserve it, but the input renaming still rewires the eight-bit ops
        // so they don't go through float.
        if (output_nodes.count(dequantize_node.name()) ||
            graph_outputs.count(dequantize_node.name())) {
          CopyOriginalMatch(match, new_nodes);
        }

        return Status::OK();
      },
      {true}, &replaced_graph_def));

  return RenameNodeInputs(replaced_graph_def, inputs_to_rename,
                          std::unordered_set<string>(), output_graph_def);
}
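// For example, where the initial per-op conversion produced
//
//   eight-bit op -> Dequantize -> QuantizeV2 (fed by Min/Max of the float)
//     -> eight-bit op
//
// the renaming above makes downstream consumers read the eight-bit tensor
// and its min/max directly, so the float round trip disappears unless some
// other part of the graph still needs the float result.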
// If the user has passed in the input_min and input_max args, then we need to
// convert any input placeholders from float to eight bit, so quantized inputs
// can be fed directly into the graph.
Status QuantizePlaceholders(const GraphDef& input_graph_def,
                            const TransformFuncContext& context,
                            GraphDef* output_graph_def) {
  float input_min;
  float input_max;
  bool has_input_range;
  TF_RETURN_IF_ERROR(ExtractRangeFromParams(context, "input_min", "input_max",
                                            &input_min, &input_max,
                                            &has_input_range));
  if (!has_input_range) {
    *output_graph_def = input_graph_def;
    return Status::OK();
  }
  std::map<string, string> inputs_to_rename_first_pass;
  std::map<string, string> inputs_to_rename_second_pass;
  GraphDef placeholder_graph_def;
  placeholder_graph_def.Clear();
  for (const NodeDef& node : input_graph_def.node()) {
    if (node.op() != "Placeholder") {
      *(placeholder_graph_def.mutable_node()->Add()) = node;
    } else {
      string namespace_prefix = node.name() + "_eightbit";

      NodeDef quantized_placeholder;
      quantized_placeholder = node;
      SetNodeAttr("dtype", DT_QUINT8, &quantized_placeholder);
      *(placeholder_graph_def.mutable_node()->Add()) = quantized_placeholder;

      NodeDef min_node;
      min_node.set_op("Const");
      min_node.set_name(namespace_prefix + "/min");
      SetNodeAttr("dtype", DT_FLOAT, &min_node);
      Tensor min_tensor(DT_FLOAT, {});
      min_tensor.flat<float>()(0) = input_min;
      SetNodeTensorAttr<float>("value", min_tensor, &min_node);
      *(placeholder_graph_def.mutable_node()->Add()) = min_node;

      NodeDef max_node;
      max_node.set_op("Const");
      max_node.set_name(namespace_prefix + "/max");
      SetNodeAttr("dtype", DT_FLOAT, &max_node);
      Tensor max_tensor(DT_FLOAT, {});
      max_tensor.flat<float>()(0) = input_max;
      SetNodeTensorAttr<float>("value", max_tensor, &max_node);
      *(placeholder_graph_def.mutable_node()->Add()) = max_node;

      const string rename_suffix = "__RENAMED_PLACEHOLDER__";
      NodeDef dequantize_node;
      dequantize_node.set_op("Dequantize");
      dequantize_node.set_name(namespace_prefix + "/dequantize");
      SetNodeAttr("T", DT_QUINT8, &dequantize_node);
      SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
      AddNodeInput(node.name() + rename_suffix, &dequantize_node);
      AddNodeInput(min_node.name(), &dequantize_node);
      AddNodeInput(max_node.name(), &dequantize_node);
      *(placeholder_graph_def.mutable_node()->Add()) = dequantize_node;

      // First make sure that any internal references to the old placeholder
      // now point to the dequantize result.
      inputs_to_rename_first_pass[node.name()] = dequantize_node.name();
      // Then fix up the dequantize op so that it really points to the
      // placeholder.
      inputs_to_rename_second_pass[node.name() + rename_suffix] = node.name();
    }
  }

  GraphDef first_pass_graph_def;
  TF_RETURN_IF_ERROR(
      RenameNodeInputs(placeholder_graph_def, inputs_to_rename_first_pass,
                       std::unordered_set<string>(), &first_pass_graph_def));
  TF_RETURN_IF_ERROR(
      RenameNodeInputs(first_pass_graph_def, inputs_to_rename_second_pass,
                       std::unordered_set<string>(), output_graph_def));

  return Status::OK();
}
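// For example, a float Placeholder named "input" is switched in place to
// dtype DT_QUINT8, and a new sub-graph
//
//   input -> input_eightbit/dequantize
//            (reading the input_eightbit/min and input_eightbit/max consts)
//
// is added, with the two renaming passes making sure every former consumer
// of "input" now reads the Dequantize's float output instead.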
// During training, FakeQuantWithMinMaxVars ops capture a good min/max range
// for an activation layer. To use these during inference, this pass converts
// those ops into Requantizes with the trained min/maxes as constant inputs.
Status ConvertFakeQuantsToRequantize(const GraphDef& input_graph_def,
                                     const TransformFuncContext& context,
                                     GraphDef* output_graph_def) {
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      input_graph_def,  // clang-format off
      {"FakeQuantWithMinMaxVars",
        {
          {"*"},
          {"Const"},
          {"Const"},
        }
      },  // clang-format on
      [](const NodeMatch& match, const std::set<string>& input_nodes,
         const std::set<string>& output_nodes,
         std::vector<NodeDef>* new_nodes) {
        const NodeDef& fake_quant_node = match.node;
        const NodeDef& original_op_node = match.inputs[0].node;
        const NodeDef& fake_quant_min_node = match.inputs[1].node;
        const NodeDef& fake_quant_max_node = match.inputs[2].node;

        string namespace_prefix = fake_quant_node.name() + "_eightbit";

        new_nodes->push_back(original_op_node);
        new_nodes->push_back(fake_quant_min_node);
        new_nodes->push_back(fake_quant_max_node);

        NodeDef quantize_node;
        quantize_node.set_op("QuantizeV2");
        quantize_node.set_name(namespace_prefix + "/quantize");
        SetNodeAttr("T", DT_QINT32, &quantize_node);
        SetNodeAttr("mode", "MIN_FIRST", &quantize_node);
        AddNodeInput(fake_quant_node.input(0), &quantize_node);
        AddNodeInput(fake_quant_min_node.name(), &quantize_node);
        AddNodeInput(fake_quant_max_node.name(), &quantize_node);
        new_nodes->push_back(quantize_node);

        NodeDef requantize_node;
        requantize_node.set_op("Requantize");
        requantize_node.set_name(namespace_prefix + "/requantize");
        SetNodeAttr("Tinput", DT_QINT32, &requantize_node);
        SetNodeAttr("out_type", DT_QUINT8, &requantize_node);
        AddNodeInput(quantize_node.name() + ":0", &requantize_node);
        AddNodeInput(quantize_node.name() + ":1", &requantize_node);
        AddNodeInput(quantize_node.name() + ":2", &requantize_node);
        AddNodeInput(fake_quant_min_node.name(), &requantize_node);
        AddNodeInput(fake_quant_max_node.name(), &requantize_node);
        new_nodes->push_back(requantize_node);

        // Convert the 8-bit result back into float for the final output.
        NodeDef dequantize_node;
        dequantize_node.set_op("Dequantize");
        dequantize_node.set_name(fake_quant_node.name());
        SetNodeAttr("T", DT_QUINT8, &dequantize_node);
        SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
        AddNodeInput(requantize_node.name() + ":0", &dequantize_node);
        AddNodeInput(requantize_node.name() + ":1", &dequantize_node);
        AddNodeInput(requantize_node.name() + ":2", &dequantize_node);
        new_nodes->push_back(dequantize_node);

        return Status::OK();
      },
      {}, output_graph_def));

  return Status::OK();
}
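// Schematically, FakeQuantWithMinMaxVars(input, min_const, max_const) becomes
//
//   QuantizeV2(input, min_const, max_const)     [float -> qint32]
//     -> Requantize(..., min_const, max_const)  [qint32 -> quint8]
//     -> Dequantize                             [quint8 -> float]
//
// with the Dequantize taking over the FakeQuant node's name, so downstream
// consumers are unaffected.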
// We always generate Requantize ops driven by dynamic RequantizationRange
// calculations when we produce quantized ops like Conv2D or BiasAdd with
// 32-bit results. If there were FakeQuant ops already for those activation
// layers, then there will be a later Requantize op with constant min/max
// inputs, which is preferable for fast inference. This pass looks for those
// later Requantize ops, and replaces the dynamic version with them.
Status MergeAdjacentRequantizes(const GraphDef& input_graph_def,
                                const TransformFuncContext& context,
                                GraphDef* output_graph_def) {
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      input_graph_def,  // clang-format off
      {"Requantize",
        {
          {"QuantizeV2",
            {
              {"Dequantize",
                {
                  {"Requantize",
                    {
                      {"*"},
                      {"*"},
                      {"*"},
                      {"RequantizationRange"},
                      {"RequantizationRange"},
                    }
                  },
                  {"Requantize"},
                  {"Requantize"},
                }
              },
              {"Const"},
              {"Const"},
            },
          },
          {"QuantizeV2"},
          {"QuantizeV2"},
          {"Const"},
          {"Const"},
        }
      },  // clang-format on
      [](const NodeMatch& match, const std::set<string>& input_nodes,
         const std::set<string>& output_nodes,
         std::vector<NodeDef>* new_nodes) {
        const NodeDef& fake_requantize_node = match.node;
        const NodeDef& original_op_node =
            match.inputs[0].inputs[0].inputs[0].inputs[0].node;
        const NodeDef& fake_requantize_min_node = match.inputs[3].node;
        const NodeDef& fake_requantize_max_node = match.inputs[4].node;

        new_nodes->push_back(original_op_node);
        new_nodes->push_back(fake_requantize_min_node);
        new_nodes->push_back(fake_requantize_max_node);

        NodeDef requantize_node;
        requantize_node = fake_requantize_node;
        requantize_node.mutable_input()->Clear();
        AddNodeInput(original_op_node.name() + ":0", &requantize_node);
        AddNodeInput(original_op_node.name() + ":1", &requantize_node);
        AddNodeInput(original_op_node.name() + ":2", &requantize_node);
        AddNodeInput(fake_requantize_min_node.name(), &requantize_node);
        AddNodeInput(fake_requantize_max_node.name(), &requantize_node);
        new_nodes->push_back(requantize_node);

        return Status::OK();
      },
      {}, output_graph_def));

  return Status::OK();
}
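// The net effect is that
//
//   32-bit op -> Requantize(RequantizationRange) -> Dequantize -> QuantizeV2
//     -> Requantize(trained min/max constants)
//
// collapses down to
//
//   32-bit op -> Requantize(trained min/max constants)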
// Sometimes FakeQuantWithMinMaxVars ops are added at the end of a chain of
// linear ops like Relu or MaxPool, several steps from the Conv2D or BiasAdd
// op that we want to apply the trained constant conversions to. This pass
// tries to move FakeQuant ops up the input chain, so they're as close as
// possible to the 32-bit conversion, and so can be easily merged into the
// automatic dynamic Requantizes.
Status HoistFakeQuants(const GraphDef& input_graph_def,
                       const TransformFuncContext& context,
                       GraphDef* output_graph_def) {
  GraphDef current_graph_def = input_graph_def;
  const int max_depth = 3;
  for (int depth = max_depth; depth > 0; --depth) {
    OpTypePattern pattern = {"*"};
    for (int i = 0; i < depth; ++i) {
      pattern = {"*", {pattern}};
    }
    pattern = {"FakeQuantWithMinMaxVars", {pattern, {"Const"}, {"Const"}}};
    GraphDef hoisted_graph_def;
    TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
        current_graph_def, pattern,
        [depth](const NodeMatch& match, const std::set<string>& input_nodes,
                const std::set<string>& output_nodes,
                std::vector<NodeDef>* new_nodes) {
          const NodeDef& fake_quant_node = match.node;
          const NodeDef& fake_quant_min_node = match.inputs[1].node;
          const NodeDef& fake_quant_max_node = match.inputs[2].node;
          std::vector<NodeDef> linear_nodes;
          NodeMatch current_match = match;
          for (int i = 0; i <= depth; ++i) {
            linear_nodes.push_back(current_match.inputs[0].node);
            current_match = current_match.inputs[0];
          }
          NodeDef new_fake_quant_node;
          new_fake_quant_node = fake_quant_node;
          new_fake_quant_node.set_name(fake_quant_node.name() + "_hoisted");
          new_fake_quant_node.set_input(
              0, linear_nodes[linear_nodes.size() - 2].input(0));
          new_nodes->push_back(new_fake_quant_node);

          new_nodes->push_back(fake_quant_min_node);
          new_nodes->push_back(fake_quant_max_node);

          linear_nodes[linear_nodes.size() - 2].set_input(
              0, new_fake_quant_node.name());
          linear_nodes.front().set_name(fake_quant_node.name());
          for (const NodeDef& linear_node : linear_nodes) {
            new_nodes->push_back(linear_node);
          }

          return Status::OK();
        },
        {}, &hoisted_graph_def));
    current_graph_def = hoisted_graph_def;
  }
  *output_graph_def = current_graph_def;

  return Status::OK();
}

// Converts any float ops that have eight-bit equivalents into their quantized
// forms, so that as much calculation as possible is done in the
// lower-precision format.
Status QuantizeNodes(const GraphDef& input_graph_def,
                     const TransformFuncContext& context,
                     GraphDef* output_graph_def) {
  // Loop through all of the quantizable op types, and replace any occurrences
  // with equivalent sub-graphs with quantized ops at their core. For example
  // this one-input operation:
  //
  //  Input(float)
  //      |
  //      v
  //   Operation
  //      |
  //      v
  //   (float)
  //
  // will be turned into its quantized equivalent:
  //
  //  Input(float)          ReshapeDims
  //      +------v v-------------+
  //      |    Reshape
  //      |      |
  //      |      |    ReductionDims
  //      |      +-----+      |
  //      |      |  +---c---------+
  //      |      v  v  v  v-------+
  //      |      Min    Max
  //      |  +----+      |
  //      v  v  v--------+
  //      Quantize
  //         |
  //         v
  //  QuantizedOperation
  //      |   |   |
  //      v   v   v
  //     Dequantize
  //         |
  //         v
  //      (float)
  //
  // This keeps the inputs and outputs visible to the rest of the graph in
  // float, and converts them down to quantized buffers internally for the
  // computation. The result will end up with a lot of redundant
  // dequantize/quantize pairs between adjacent quantized ops, but a later
  // pass removes these where it can.
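
  // The pattern handed to ReplaceMatchingOpTypes below is built as a single
  // string of float op names joined with "|" (e.g. "Add|AvgPool|BiasAdd|...")
  // which it treats as matching any one of those op types.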
  std::set<string> ops_to_ignore;
  if (context.params.count("ignore_op") > 0) {
    for (const string& name : context.params.at("ignore_op")) {
      ops_to_ignore.insert(name);
    }
  }

  const std::vector<QuantizedOpInfo>& op_list = GetQuantizedOpList();
  string op_pattern;
  bool is_first = true;
  std::map<string, QuantizedOpInfo> op_map;
  for (const QuantizedOpInfo& op_info : op_list) {
    if (ops_to_ignore.count(op_info.float_name) == 0) {
      strings::StrAppend(&op_pattern, (is_first ? "" : "|"),
                         op_info.float_name);
      op_map.insert({op_info.float_name, op_info});
      is_first = false;
    }
  }

  // If input_min and input_max have been passed in, then we convert all float
  // Placeholder nodes into quantized versions, with the supplied values as
  // their range.
  GraphDef placeholder_graph_def;
  TF_RETURN_IF_ERROR(
      QuantizePlaceholders(input_graph_def, context, &placeholder_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(placeholder_graph_def));

  // If there are any FakeQuantWithMinMaxVars at the end of a chain of linear
  // operations like Relu or MaxPool, move them up so that they're as close as
  // possible to ops with 32-bit outputs like BiasAdd or Conv2D.
  GraphDef hoisted_graph_def;
  TF_RETURN_IF_ERROR(
      HoistFakeQuants(placeholder_graph_def, context, &hoisted_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(hoisted_graph_def));

  // Convert any FakeQuantWithMinMaxVars, which hold the trained ranges of
  // activation layers, into Requantize ops with those ranges instead. This
  // makes it easier to replace the dynamic range calculations that are used
  // by default.
  GraphDef converted_graph_def;
  TF_RETURN_IF_ERROR(ConvertFakeQuantsToRequantize(hoisted_graph_def, context,
                                                   &converted_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(converted_graph_def));

  // If fallback_min and fallback_max are set, then we'll use hardwired ranges
  // for all the 32-bit to 8-bit requantizations.
  float fallback_min;
  float fallback_max;
  bool has_fallback_range;
  TF_RETURN_IF_ERROR(ExtractRangeFromParams(
      context, "fallback_min", "fallback_max", &fallback_min, &fallback_max,
      &has_fallback_range));

  // Replace all occurrences of the current float op with its quantized
  // equivalent.
  GraphDef quantized_graph_def;
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      converted_graph_def, {op_pattern},
      [&op_map, fallback_min, fallback_max, has_fallback_range](
          const NodeMatch& match, const std::set<string>& input_nodes,
          const std::set<string>& output_nodes,
          std::vector<NodeDef>* new_nodes) {
        const NodeDef& float_node = match.node;
        const QuantizedOpInfo& op_info = op_map[float_node.op()];

        DataTypeVector input_types;
        DataTypeVector output_types;
        TF_RETURN_IF_ERROR(
            GetInOutTypes(float_node, &input_types, &output_types));
        bool are_all_float = true;
        for (int i = 0; i < float_node.input_size(); ++i) {
          // Skip any known non-float inputs.
          if (op_info.unquantized_inputs.count(i)) {
            continue;
          }
          if (input_types[i] != DT_FLOAT) {
            are_all_float = false;
          }
        }
        for (const DataType& output_type : output_types) {
          if (output_type != DT_FLOAT) {
            are_all_float = false;
          }
        }
        // This isn't a float op, so don't quantize it.
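        // For example, an Add that operates on int32 tensors rather than
        // float is left untouched here.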
        if (!are_all_float) {
          CopyOriginalMatch(match, new_nodes);
          return Status::OK();
        }

        string namespace_prefix = float_node.name() + "_eightbit";

        // Quantize all of the inputs.
        std::vector<string> quantized_input_names;
        for (int i = 0; i < float_node.input_size(); ++i) {
          // Skip any non-float inputs.
          if (op_info.unquantized_inputs.count(i)) {
            continue;
          }

          const string& input_name = float_node.input(i);
          string unique_input_name =
              namespace_prefix + "/" + UniqueNodeNameFromInput(input_name);

          // Add some common constants we need for reshaping inputs.
          NodeDef reshape_dims;
          reshape_dims.set_op("Const");
          reshape_dims.set_name(unique_input_name + "/reshape_dims");
          AddNodeInput("^" + NodeNameFromInput(input_name), &reshape_dims);
          SetNodeAttr("dtype", DT_INT32, &reshape_dims);
          Tensor reshape_dims_tensor(DT_INT32, {1});
          reshape_dims_tensor.flat<int32>()(0) = -1;
          SetNodeTensorAttr<int32>("value", reshape_dims_tensor,
                                   &reshape_dims);
          new_nodes->push_back(reshape_dims);

          NodeDef reduction_dims;
          reduction_dims.set_op("Const");
          reduction_dims.set_name(unique_input_name + "/reduction_dims");
          AddNodeInput("^" + NodeNameFromInput(input_name), &reduction_dims);
          SetNodeAttr("dtype", DT_INT32, &reduction_dims);
          Tensor reduction_dims_tensor(DT_INT32, {1});
          reduction_dims_tensor.flat<int32>()(0) = 0;
          SetNodeTensorAttr<int32>("value", reduction_dims_tensor,
                                   &reduction_dims);
          new_nodes->push_back(reduction_dims);

          NodeDef reshape_node;
          reshape_node.set_op("Reshape");
          reshape_node.set_name(unique_input_name + "/reshape");
          SetNodeAttr("T", DT_FLOAT, &reshape_node);
          AddNodeInput(input_name, &reshape_node);
          AddNodeInput(reshape_dims.name(), &reshape_node);
          new_nodes->push_back(reshape_node);

          NodeDef min_node;
          min_node.set_op("Min");
          min_node.set_name(unique_input_name + "/min");
          SetNodeAttr("T", DT_FLOAT, &min_node);
          SetNodeAttr("keep_dims", false, &min_node);
          AddNodeInput(reshape_node.name(), &min_node);
          AddNodeInput(reduction_dims.name(), &min_node);
          new_nodes->push_back(min_node);

          NodeDef max_node;
          max_node.set_op("Max");
          max_node.set_name(unique_input_name + "/max");
          SetNodeAttr("T", DT_FLOAT, &max_node);
          SetNodeAttr("keep_dims", false, &max_node);
          AddNodeInput(reshape_node.name(), &max_node);
          AddNodeInput(reduction_dims.name(), &max_node);
          new_nodes->push_back(max_node);

          NodeDef quantize_node;
          quantize_node.set_op("QuantizeV2");
          quantize_node.set_name(unique_input_name + "/quantize");
          SetNodeAttr("T", DT_QUINT8, &quantize_node);
          SetNodeAttr("mode", "MIN_FIRST", &quantize_node);
          AddNodeInput(input_name, &quantize_node);
          AddNodeInput(min_node.name(), &quantize_node);
          AddNodeInput(max_node.name(), &quantize_node);
          new_nodes->push_back(quantize_node);
          quantized_input_names.push_back(quantize_node.name());
        }

        // Set up the quantized version of the current op.
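        // For example, a float node named "conv1" with op "Conv2D" becomes a
        // node named "conv1/eightbit" with op "QuantizedConv2D".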
        NodeDef quantized_main_node;
        quantized_main_node.set_op("Quantized" + float_node.op());
        quantized_main_node.set_name(float_node.name() + "/eightbit");
        for (const string& attr_to_copy : op_info.attrs_to_copy) {
          CopyNodeAttr(float_node, attr_to_copy, attr_to_copy,
                       &quantized_main_node);
        }
        for (const std::pair<string, DataType>& dtype_to_set :
             op_info.dtypes_to_set) {
          SetNodeAttr(dtype_to_set.first, dtype_to_set.second,
                      &quantized_main_node);
        }
        int quantized_input_index = 0;
        for (int i = 0; i < float_node.input_size(); ++i) {
          if (op_info.unquantized_inputs.count(i)) {
            AddNodeInput(float_node.input(i), &quantized_main_node);
          } else {
            const string& quantized_input_name =
                quantized_input_names[quantized_input_index];
            AddNodeInput(quantized_input_name + ":0", &quantized_main_node);
            ++quantized_input_index;
          }
        }
        if (op_info.min_max_order == QuantizedOpInfo::CONTIGUOUS_MIN_MAX) {
          for (const string& quantized_input_name : quantized_input_names) {
            AddNodeInput(quantized_input_name + ":1", &quantized_main_node);
            AddNodeInput(quantized_input_name + ":2", &quantized_main_node);
          }
        } else {
          for (const string& quantized_input_name : quantized_input_names) {
            AddNodeInput(quantized_input_name + ":1", &quantized_main_node);
          }
          for (const string& quantized_input_name : quantized_input_names) {
            AddNodeInput(quantized_input_name + ":2", &quantized_main_node);
          }
        }
        new_nodes->push_back(quantized_main_node);

        string eight_bit_node_name;
        if (op_info.output_bit_depth == DT_QINT32) {
          // Shrink the range of the output down from 32 bits to 8.
          string requantize_min_input;
          string requantize_max_input;
          if (has_fallback_range) {
            // Use constant values for the min/max range if they were given.
            NodeDef fallback_min_node;
            fallback_min_node.set_op("Const");
            fallback_min_node.set_name(quantized_main_node.name() +
                                       "/fallback_min");
            SetNodeAttr("dtype", DT_FLOAT, &fallback_min_node);
            Tensor fallback_min_tensor(DT_FLOAT, {});
            fallback_min_tensor.flat<float>()(0) = fallback_min;
            SetNodeTensorAttr<float>("value", fallback_min_tensor,
                                     &fallback_min_node);
            new_nodes->push_back(fallback_min_node);

            NodeDef fallback_max_node;
            fallback_max_node.set_op("Const");
            fallback_max_node.set_name(quantized_main_node.name() +
                                       "/fallback_max");
            SetNodeAttr("dtype", DT_FLOAT, &fallback_max_node);
            Tensor fallback_max_tensor(DT_FLOAT, {});
            fallback_max_tensor.flat<float>()(0) = fallback_max;
            SetNodeTensorAttr<float>("value", fallback_max_tensor,
                                     &fallback_max_node);
            new_nodes->push_back(fallback_max_node);

            requantize_min_input = fallback_min_node.name();
            requantize_max_input = fallback_max_node.name();
          } else {
            // Otherwise dynamically measure the range each time.
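            // RequantizationRange examines the 32-bit output to find the
            // min/max values actually present, and its two float outputs
            // drive the Requantize below.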
            NodeDef requant_range_node;
            requant_range_node.set_op("RequantizationRange");
            requant_range_node.set_name(quantized_main_node.name() +
                                        "/requant_range");
            SetNodeAttr("Tinput", DT_QINT32, &requant_range_node);
            AddNodeInput(quantized_main_node.name() + ":0",
                         &requant_range_node);
            AddNodeInput(quantized_main_node.name() + ":1",
                         &requant_range_node);
            AddNodeInput(quantized_main_node.name() + ":2",
                         &requant_range_node);
            new_nodes->push_back(requant_range_node);

            requantize_min_input = requant_range_node.name() + ":0";
            requantize_max_input = requant_range_node.name() + ":1";
          }
          NodeDef requantize_node;
          requantize_node.set_op("Requantize");
          requantize_node.set_name(quantized_main_node.name() + "/requantize");
          SetNodeAttr("Tinput", DT_QINT32, &requantize_node);
          SetNodeAttr("out_type", DT_QUINT8, &requantize_node);
          AddNodeInput(quantized_main_node.name() + ":0", &requantize_node);
          AddNodeInput(quantized_main_node.name() + ":1", &requantize_node);
          AddNodeInput(quantized_main_node.name() + ":2", &requantize_node);
          AddNodeInput(requantize_min_input, &requantize_node);
          AddNodeInput(requantize_max_input, &requantize_node);
          new_nodes->push_back(requantize_node);
          eight_bit_node_name = requantize_node.name();
        } else {
          eight_bit_node_name = quantized_main_node.name();
        }

        // Convert the 8-bit result back into float for the final output.
        NodeDef dequantize_node;
        dequantize_node.set_op("Dequantize");
        dequantize_node.set_name(float_node.name());
        SetNodeAttr("T", DT_QUINT8, &dequantize_node);
        SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
        AddNodeInput(eight_bit_node_name + ":0", &dequantize_node);
        AddNodeInput(eight_bit_node_name + ":1", &dequantize_node);
        AddNodeInput(eight_bit_node_name + ":2", &dequantize_node);
        new_nodes->push_back(dequantize_node);

        return Status::OK();
      },
      {}, &quantized_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(quantized_graph_def));

  // If we've ended up with two Requantize ops in a row (for example if there
  // was a Conv2D feeding into a FakeQuantWithMinMaxVars), merge them
  // together, using the trained range from the second op.
  GraphDef merged_graph_def;
  TF_RETURN_IF_ERROR(MergeAdjacentRequantizes(quantized_graph_def, context,
                                              &merged_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(merged_graph_def));

  // There can be duplicate quantize nodes if multiple ops pull from a single
  // input, which makes it harder to remove redundant ones, so strip them out.
  GraphDef deduped_graph_def;
  TF_RETURN_IF_ERROR(
      MergeDuplicateNodes(merged_graph_def, context, &deduped_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(deduped_graph_def));

  // Look for Dequantizes that immediately go into Quantizes, and remove them
  // since the two together cancel each other out. This allows us to keep the
  // data flow in eight bit where two adjacent ops are in eight bit, but still
  // keep interoperability with float ops.
  TF_RETURN_IF_ERROR(RemoveRedundantQuantizations(deduped_graph_def, context,
                                                  output_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(*output_graph_def));

  return Status::OK();
}

REGISTER_GRAPH_TRANSFORM("quantize_nodes", QuantizeNodes);

REGISTER_GRAPH_TRANSFORM("merge_duplicate_nodes", MergeDuplicateNodes);

}  // namespace graph_transforms
}  // namespace tensorflow
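// Example invocation (paths and values illustrative), using the standard
// transform_graph tool that drives these registered transforms:
//
//   bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
//     --in_graph=model.pb --out_graph=quantized_model.pb \
//     --inputs=input --outputs=softmax \
//     --transforms='quantize_nodes(input_min=-1, input_max=1)'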