1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/tools/graph_transforms/transform_utils.h" 17 18 #include "tensorflow/core/framework/node_def_util.h" 19 #include "tensorflow/core/framework/op.h" 20 #include "tensorflow/core/lib/hash/hash.h" 21 #include "tensorflow/core/lib/strings/str_util.h" 22 23 namespace tensorflow { 24 namespace graph_transforms { 25 26 namespace { 27 inline bool IsMerge(const NodeDef& node_def) { 28 return node_def.op() == "Merge" || node_def.op() == "RefMerge"; 29 } 30 31 void RecordMatchedNodes(const NodeMatch& match, 32 std::set<string>* matched_nodes) { 33 matched_nodes->insert(match.node.name()); 34 for (const NodeMatch& input_match : match.inputs) { 35 RecordMatchedNodes(input_match, matched_nodes); 36 } 37 } 38 39 inline uint64 Hash64String(const string& input) { 40 return Hash64(input.data(), input.size()); 41 } 42 } // namespace 43 44 void MatchedNodesAsArray(const NodeMatch& match, std::vector<NodeDef>* result) { 45 std::set<string> found_nodes; 46 std::vector<NodeMatch> current_matches = {match}; 47 while (!current_matches.empty()) { 48 std::vector<NodeMatch> next_matches; 49 for (const NodeMatch& current_match : current_matches) { 50 if (found_nodes.count(current_match.node.name())) { 51 continue; 52 } 53 found_nodes.insert(current_match.node.name()); 54 result->push_back(current_match.node); 55 for (const NodeMatch& input_match : current_match.inputs) { 56 next_matches.push_back(input_match); 57 } 58 } 59 current_matches = next_matches; 60 } 61 } 62 63 void MapNamesToNodes(const GraphDef& graph_def, 64 std::map<string, const NodeDef*>* result) { 65 for (const NodeDef& node : graph_def.node()) { 66 (*result)[node.name()] = &node; 67 } 68 } 69 70 void MapNodesToOutputs(const GraphDef& graph_def, 71 std::map<string, std::vector<const NodeDef*>>* result) { 72 std::map<string, const NodeDef*> node_map; 73 MapNamesToNodes(graph_def, &node_map); 74 for (const NodeDef& node : graph_def.node()) { 75 for (const string& input : node.input()) { 76 string input_node_name = NodeNameFromInput(input); 77 (*result)[input_node_name].push_back(&node); 78 } 79 } 80 } 81 82 void NodeNamePartsFromInput(const string& input_name, string* prefix, 83 string* node_name, string* suffix) { 84 std::vector<string> input_parts = str_util::Split(input_name, ':'); 85 if (input_parts.size() < 2) { 86 *suffix = ""; 87 } else { 88 *suffix = ":" + input_parts[1]; 89 } 90 StringPiece node_name_piece(input_parts[0]); 91 if (node_name_piece.Consume("^")) { 92 *prefix = "^"; 93 } else { 94 *prefix = ""; 95 } 96 *node_name = node_name_piece.ToString(); 97 } 98 99 string NodeNameFromInput(const string& input_name) { 100 string prefix; 101 string node_name; 102 string suffix; 103 NodeNamePartsFromInput(input_name, &prefix, &node_name, &suffix); 104 return node_name; 105 } 106 107 string CanonicalInputName(const string& input_name) { 108 string prefix; 109 string node_name; 110 string suffix; 111 NodeNamePartsFromInput(input_name, &prefix, &node_name, &suffix); 112 if (suffix.empty()) { 113 suffix = ":0"; 114 } 115 return prefix + node_name + suffix; 116 } 117 118 uint64 HashNodeDef(const NodeDef& node) { 119 uint64 hash = Hash64String(node.op()); 120 hash = Hash64Combine(hash, Hash64String(node.name())); 121 for (const string& input : node.input()) { 122 hash = Hash64Combine(hash, Hash64String(CanonicalInputName(input))); 123 } 124 hash = Hash64Combine(hash, Hash64String(node.device())); 125 std::vector<string> attr_names; 126 attr_names.reserve(node.attr().size()); 127 for (const auto& attr : node.attr()) { 128 attr_names.push_back(attr.first); 129 } 130 std::sort(attr_names.begin(), attr_names.end()); 131 string attr_serialized; 132 for (const string& attr_name : attr_names) { 133 auto attr = node.attr().at(attr_name); 134 attr.SerializeToString(&attr_serialized); 135 hash = Hash64Combine(hash, Hash64String(attr_serialized)); 136 } 137 return hash; 138 } 139 140 void AddNodeInput(const string& input_name, NodeDef* node) { 141 *(node->mutable_input()->Add()) = input_name; 142 } 143 144 void CopyNodeAttr(const NodeDef& source, const string& source_key, 145 const string& dest_key, NodeDef* dest) { 146 CHECK_NE(0, source.attr().count(source_key)) 147 << "No key '" << source_key << "' found in " << source.DebugString(); 148 (*(dest->mutable_attr()))[dest_key] = source.attr().at(source_key); 149 } 150 151 Tensor GetNodeTensorAttr(const NodeDef& node, const string& key) { 152 TensorProto tensor_proto = node.attr().at(key).tensor(); 153 Tensor tensor; 154 CHECK(tensor.FromProto(tensor_proto)); 155 return tensor; 156 } 157 158 void FilterGraphDef(const GraphDef& input_graph_def, 159 std::function<bool(const NodeDef&)> selector, 160 GraphDef* output_graph_def) { 161 output_graph_def->mutable_node()->Clear(); 162 for (const NodeDef& node : input_graph_def.node()) { 163 if (selector(node)) { 164 *output_graph_def->mutable_node()->Add() = node; 165 } 166 } 167 } 168 169 void RemoveAttributes(const GraphDef& input_graph_def, 170 const std::vector<string>& attributes, 171 GraphDef* output_graph_def) { 172 output_graph_def->mutable_node()->Clear(); 173 for (const NodeDef& node : input_graph_def.node()) { 174 NodeDef* new_node = output_graph_def->mutable_node()->Add(); 175 *new_node = node; 176 for (const string& attribute : attributes) { 177 new_node->mutable_attr()->erase(attribute); 178 } 179 } 180 } 181 182 Status SortByExecutionOrder(const GraphDef& input_graph_def, 183 GraphDef* output_graph_def) { 184 const int num_nodes = input_graph_def.node_size(); 185 std::vector<int> ready; 186 std::vector<int> pending_count; 187 pending_count.reserve(num_nodes); 188 std::vector<gtl::InlinedVector<int, 4>> outputs(num_nodes); 189 190 std::map<string, int> name_index; 191 for (int i = 0; i < input_graph_def.node_size(); ++i) { 192 const NodeDef& node(input_graph_def.node(i)); 193 name_index[node.name()] = i; 194 } 195 196 // Parse the inputs for each node. 197 for (int n = 0; n < num_nodes; ++n) { 198 const NodeDef& node_def(input_graph_def.node(n)); 199 if (IsMerge(node_def)) { 200 // for merge only wait for one non-control input. 201 int32 num_control_edges = 0; 202 for (int i = 0; i < node_def.input_size(); ++i) { 203 StringPiece input_name(node_def.input(i)); 204 if (input_name.starts_with("^")) { 205 num_control_edges++; 206 } 207 } 208 pending_count.push_back(num_control_edges + 1); 209 } else { 210 pending_count.push_back(node_def.input_size()); 211 } 212 if (node_def.input_size() == 0) { 213 ready.push_back(n); 214 continue; 215 } 216 for (int i = 0; i < node_def.input_size(); ++i) { 217 const string& input_name = node_def.input(i); 218 const string& input_node_name = NodeNameFromInput(input_name); 219 if (!name_index.count(input_node_name)) { 220 return errors::InvalidArgument("Node '", node_def.name(), 221 "': Unknown input node '", 222 node_def.input(i), "'"); 223 } 224 outputs[name_index[input_node_name]].push_back(n); 225 } 226 } 227 228 int processed = 0; 229 output_graph_def->Clear(); 230 // Process the NodeDefs in topological order. 231 // Code above sets this up by filling in ready_ with nodes that have no 232 // inputs, pending_counts_ with the number of inputs for each node and 233 // outputs_ with the outputs of each node. 234 while (!ready.empty()) { 235 int o = ready.back(); 236 ready.pop_back(); 237 ++processed; 238 const NodeDef& node_def(input_graph_def.node(o)); 239 *output_graph_def->mutable_node()->Add() = node_def; 240 241 // Update pending_count for outputs. 242 for (size_t i = 0; i < outputs[o].size(); ++i) { 243 const int output = outputs[o][i]; 244 pending_count[output]--; 245 if (pending_count[output] == 0) { 246 ready.push_back(output); 247 } 248 } 249 } 250 251 if (processed < input_graph_def.node_size()) { 252 return errors::InvalidArgument(input_graph_def.node_size() - processed, 253 " nodes in a cycle"); 254 } 255 return Status::OK(); 256 } 257 258 string OpTypePattern::DebugString() const { 259 string result = "{" + op + ", {"; 260 for (const OpTypePattern& input : inputs) { 261 result += input.DebugString() + ","; 262 } 263 result += "}}"; 264 return result; 265 } 266 267 string NodeMatch::DebugString() const { 268 string result = "{"; 269 result += node.DebugString(); 270 result += ", {"; 271 for (const NodeMatch& input : inputs) { 272 result += input.DebugString() + ","; 273 } 274 result += "}}"; 275 return result; 276 } 277 278 GraphMatcher::GraphMatcher(const GraphDef& graph_def) { 279 SortByExecutionOrder(graph_def, &graph_def_).IgnoreError(); 280 MapNamesToNodes(graph_def_, &node_map_); 281 } 282 283 Status GraphMatcher::GetOpTypeMatches(const OpTypePattern& pattern, 284 std::vector<NodeMatch>* matches) { 285 std::set<string> matched_nodes; 286 for (const NodeDef& node : graph_def_.node()) { 287 // Skip any nodes that are already part of a match. 288 if (matched_nodes.count(node.name())) { 289 continue; 290 } 291 NodeMatch match; 292 if (DoesOpTypeMatch(node, pattern, matched_nodes, &match)) { 293 RecordMatchedNodes(match, &matched_nodes); 294 matches->push_back(match); 295 } 296 } 297 return Status::OK(); 298 } 299 300 bool GraphMatcher::DoesOpTypeMatch( 301 const NodeDef& node, const OpTypePattern& pattern, 302 const std::set<string>& previously_matched_nodes, NodeMatch* match) { 303 VLOG(1) << "Looking at node " << node.DebugString(); 304 VLOG(1) << "pattern=" << pattern.DebugString(); 305 VLOG(1) << "match=" << match->DebugString(); 306 if (previously_matched_nodes.count(node.name())) { 307 VLOG(1) << "node " << node.name() << " has been previously matched"; 308 return false; 309 } 310 bool pattern_matched = false; 311 if (pattern.op == "*") { 312 pattern_matched = true; 313 } else { 314 std::vector<string> pattern_ops = str_util::Split(pattern.op, '|'); 315 for (const string& pattern_op : pattern_ops) { 316 if (node.op() == pattern_op) { 317 pattern_matched = true; 318 } 319 } 320 } 321 if (!pattern_matched) { 322 VLOG(1) << "node.op() != pattern.op()"; 323 return false; 324 } 325 match->node = node; 326 // Ignore any control inputs for pattern-matching purposes 327 std::vector<string> non_control_inputs; 328 for (const string& input : node.input()) { 329 if (!input.empty() && (input[0] != '^')) { 330 non_control_inputs.push_back(input); 331 } 332 } 333 if (pattern.inputs.empty()) { 334 // If there are no inputs, assume that's the end of the pattern. 335 return true; 336 } 337 if (non_control_inputs.size() != pattern.inputs.size()) { 338 VLOG(1) << "non_control_inputs.size() != pattern.inputs.size()"; 339 return false; 340 } 341 for (int i = 0; i < pattern.inputs.size(); ++i) { 342 const string& input_node_name = NodeNameFromInput(non_control_inputs[i]); 343 const NodeDef& input_node = *(node_map_[input_node_name]); 344 const OpTypePattern& input_pattern = pattern.inputs[i]; 345 match->inputs.push_back(NodeMatch()); 346 NodeMatch* input_match = &(match->inputs.back()); 347 if (!DoesOpTypeMatch(input_node, input_pattern, previously_matched_nodes, 348 input_match)) { 349 return false; 350 } 351 } 352 return true; 353 } 354 355 Status ReplaceMatchingOpTypes( 356 const GraphDef& input_graph_def, const OpTypePattern& pattern, 357 const std::function<Status(const NodeMatch&, const std::set<string>&, 358 const std::set<string>&, std::vector<NodeDef>*)>& 359 node_generator, 360 const ReplaceMatchingOpTypesOptions& options, GraphDef* output_graph_def) { 361 // Start off by retrieving all the matching subgraphs. 362 GraphMatcher matcher(input_graph_def); 363 std::vector<NodeMatch> matches; 364 TF_RETURN_IF_ERROR(matcher.GetOpTypeMatches(pattern, &matches)); 365 366 // Do some housekeeping so we can easily look up the resulting matches given 367 // a node name. 368 std::set<string> matched_nodes; 369 std::map<string, const NodeMatch*> matches_by_head_name; 370 for (const NodeMatch& match : matches) { 371 matches_by_head_name[match.node.name()] = &match; 372 RecordMatchedNodes(match, &matched_nodes); 373 } 374 std::map<string, std::vector<const NodeDef*>> outputs_map; 375 MapNodesToOutputs(input_graph_def, &outputs_map); 376 377 // Go through all the nodes in the input graph, see if they are part of a 378 // match or if they can be left untouched. 379 output_graph_def->Clear(); 380 for (const NodeDef& input_node : input_graph_def.node()) { 381 if (matches_by_head_name.count(input_node.name())) { 382 // This node is the beginning of a match, so call the replacement function 383 // after setting up some information it will need. 384 const NodeMatch* match = matches_by_head_name[input_node.name()]; 385 std::vector<NodeDef> matched_nodes_array; 386 MatchedNodesAsArray(*match, &matched_nodes_array); 387 // This tells us whether a node is part of the current match. 388 std::set<string> matched_nodes_lookup; 389 for (const NodeDef& matched_node : matched_nodes_array) { 390 matched_nodes_lookup.insert(matched_node.name()); 391 } 392 // These are helper arrays that the replacement function can use to tell 393 // whether it can safely remove an internal node (because nothing outside 394 // of the match uses it) or whether external nodes depend on it. 395 std::set<string> input_nodes; 396 std::set<string> output_nodes; 397 for (const NodeDef& matched_node : matched_nodes_array) { 398 // Look through all of this node's inputs, and if any of them come from 399 // outside the match, then this should be noted as one of the external 400 // inputs of the subgraph. 401 for (const string& input_name : matched_node.input()) { 402 string input_node_name = NodeNameFromInput(input_name); 403 if (!matched_nodes_lookup.count(input_node_name)) { 404 input_nodes.insert(matched_node.name()); 405 } 406 } 407 // Do a reverse input lookup, to see which other nodes use the current 408 // one as an input. If any of those nodes are outside the match 409 // subgraph, then the current node is marked as an output node that 410 // shouldn't be removed. 411 if (outputs_map.count(matched_node.name())) { 412 for (const NodeDef* dependent_node : 413 outputs_map[matched_node.name()]) { 414 if (!matched_nodes_lookup.count(dependent_node->name())) { 415 output_nodes.insert(matched_node.name()); 416 } 417 } 418 } 419 } 420 // Call the generator function and add all the returned nodes to the 421 // graph. 422 std::vector<NodeDef> new_nodes; 423 TF_RETURN_IF_ERROR( 424 node_generator(*match, input_nodes, output_nodes, &new_nodes)); 425 std::set<string> new_node_names; 426 for (const NodeDef& new_node : new_nodes) { 427 new_node_names.insert(new_node.name()); 428 } 429 // Check to make sure the generator function preserved all of the nodes 430 // that are used elsewhere in the graph, and add them back in if not. 431 bool abort_replacement = false; 432 if (!options.allow_inconsistencies) { 433 for (const string& expected_output : output_nodes) { 434 if (!new_node_names.count(expected_output)) { 435 LOG(WARNING) << "Expected " << expected_output 436 << " to be preserved."; 437 abort_replacement = true; 438 } 439 } 440 } 441 if (abort_replacement) { 442 LOG(WARNING) << "Generator function didn't preserve needed nodes, " 443 << "copying old replacements back in instead."; 444 std::vector<NodeDef> old_nodes; 445 MatchedNodesAsArray(*match, &old_nodes); 446 for (const NodeDef& old_node : old_nodes) { 447 NodeDef* added_node = output_graph_def->mutable_node()->Add(); 448 *added_node = old_node; 449 } 450 } else { 451 for (const NodeDef& new_node : new_nodes) { 452 NodeDef* added_node = output_graph_def->mutable_node()->Add(); 453 *added_node = new_node; 454 } 455 } 456 } else if (!matched_nodes.count(input_node.name())) { 457 // This node isn't part of any match, so just copy it over. 458 NodeDef* added_node = output_graph_def->mutable_node()->Add(); 459 *added_node = input_node; 460 } else { 461 // Do nothing, because this is an internal part of a matching subgraph, 462 // and so will have been replaced by a new replacement subgraph. 463 } 464 } 465 466 return Status::OK(); 467 } 468 469 Status RenameNodeInputs(const GraphDef& input_graph_def, 470 const std::map<string, string>& inputs_to_rename, 471 const std::unordered_set<string>& nodes_to_ignore, 472 GraphDef* output_graph_def) { 473 std::map<string, std::vector<std::pair<string, string>>> 474 canonical_inputs_to_rename; 475 for (const auto& input_to_rename : inputs_to_rename) { 476 canonical_inputs_to_rename[NodeNameFromInput(input_to_rename.first)] 477 .push_back({input_to_rename.first, input_to_rename.second}); 478 } 479 480 output_graph_def->Clear(); 481 for (const NodeDef& node : input_graph_def.node()) { 482 NodeDef* new_node = output_graph_def->mutable_node()->Add(); 483 *new_node = node; 484 new_node->mutable_input()->Clear(); 485 for (const string& input_name : node.input()) { 486 std::set<string> already_visited; 487 string new_input_name = input_name; 488 while ( 489 canonical_inputs_to_rename.count(NodeNameFromInput(new_input_name))) { 490 string input_node_name = NodeNameFromInput(new_input_name); 491 if (already_visited.count(input_node_name)) { 492 return errors::InvalidArgument( 493 "RenameNodeInputs argument contains a cycle for ", 494 input_node_name); 495 } 496 already_visited.insert(input_node_name); 497 if (nodes_to_ignore.count(node.name())) { 498 break; 499 } 500 bool any_match_found = false; 501 for (const std::pair<string, string>& input_to_rename : 502 canonical_inputs_to_rename.at(input_node_name)) { 503 const string& source_name = input_to_rename.first; 504 const string& dest_name = input_to_rename.second; 505 bool is_match; 506 string match_name; 507 if (StringPiece(source_name).ends_with(":*")) { 508 is_match = true; 509 string prefix; 510 string unused_node_name; 511 string suffix; 512 NodeNamePartsFromInput(new_input_name, &prefix, &unused_node_name, 513 &suffix); 514 match_name = prefix + dest_name + suffix; 515 } else { 516 is_match = (CanonicalInputName(source_name) == 517 CanonicalInputName(new_input_name)); 518 match_name = dest_name; 519 } 520 if (is_match) { 521 new_input_name = match_name; 522 any_match_found = true; 523 } 524 } 525 if (!any_match_found) { 526 break; 527 } 528 } 529 *(new_node->mutable_input()->Add()) = new_input_name; 530 } 531 } 532 return Status::OK(); 533 } 534 535 void CopyOriginalMatch(const NodeMatch& match, 536 std::vector<NodeDef>* new_nodes) { 537 std::vector<NodeDef> old_nodes; 538 MatchedNodesAsArray(match, &old_nodes); 539 for (const NodeDef& old_node : old_nodes) { 540 new_nodes->push_back(old_node); 541 } 542 } 543 544 TransformRegistry* GetTransformRegistry() { 545 static TransformRegistry transform_registry; 546 return &transform_registry; 547 } 548 549 void FindInvalidInputs(const GraphDef& graph_def, 550 std::vector<std::pair<string, string>>* invalid_inputs) { 551 std::map<string, const NodeDef*> node_map; 552 MapNamesToNodes(graph_def, &node_map); 553 554 for (const NodeDef& node : graph_def.node()) { 555 for (const string& input : node.input()) { 556 string input_node = NodeNameFromInput(input); 557 if (!node_map.count(input_node)) { 558 invalid_inputs->push_back({node.name(), input_node}); 559 } 560 } 561 } 562 } 563 564 Status IsGraphValid(const GraphDef& graph_def) { 565 std::vector<std::pair<string, string>> invalid_inputs; 566 FindInvalidInputs(graph_def, &invalid_inputs); 567 if (!invalid_inputs.empty()) { 568 std::map<string, const NodeDef*> node_map; 569 MapNamesToNodes(graph_def, &node_map); 570 for (const std::pair<string, string>& invalid_input : invalid_inputs) { 571 LOG(ERROR) << "Invalid input " << invalid_input.second << " for node " 572 << invalid_input.first << " - " 573 << node_map[invalid_input.first]->DebugString(); 574 } 575 return errors::Internal( 576 "Invalid graph with inputs referring to nonexistent nodes"); 577 } 578 return Status::OK(); 579 } 580 581 Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs, 582 DataTypeVector* outputs) { 583 const OpDef* op_def; 584 TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def)); 585 TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, *op_def, inputs, outputs)); 586 return Status::OK(); 587 } 588 589 Status TensorShapeFromString(const string& shape_string, TensorShape* result) { 590 if (shape_string.empty()) { 591 return errors::InvalidArgument("Specificed shape is empty."); 592 } 593 std::vector<int64> dims; 594 if (!str_util::SplitAndParseAsInts(shape_string, ',', &dims)) { 595 return errors::InvalidArgument("Could parse as shape: '", shape_string, 596 "'"); 597 } 598 *result = TensorShape(dims); 599 return Status::OK(); 600 } 601 602 int TransformFuncContext::CountParameters(const string& name) const { 603 if (params.count(name)) { 604 return params.at(name).size(); 605 } else { 606 return 0; 607 } 608 } 609 610 Status TransformFuncContext::GetOneStringParameter(const string& name, 611 const string& default_value, 612 string* result) const { 613 const int params_count = CountParameters(name); 614 if (params_count == 0) { 615 *result = default_value; 616 return Status::OK(); 617 } else if (params_count == 1) { 618 *result = params.at(name).at(0); 619 return Status::OK(); 620 } else { 621 return errors::InvalidArgument("Expected a single '", name, 622 "' parameter, but found ", params_count, 623 " occurrences"); 624 } 625 } 626 627 Status TransformFuncContext::GetOneInt32Parameter(const string& name, 628 int32 default_value, 629 int32* result) const { 630 const int params_count = CountParameters(name); 631 if (params_count == 0) { 632 *result = default_value; 633 return Status::OK(); 634 } 635 string string_value; 636 TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value)); 637 if (!strings::safe_strto32(StringPiece(string_value), result)) { 638 return errors::InvalidArgument("Couldn't interpret the ", name, 639 " argument as a number:", string_value); 640 } 641 return Status::OK(); 642 } 643 644 Status TransformFuncContext::GetOneInt64Parameter(const string& name, 645 int64 default_value, 646 int64* result) const { 647 const int params_count = CountParameters(name); 648 if (params_count == 0) { 649 *result = default_value; 650 return Status::OK(); 651 } 652 string string_value; 653 TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value)); 654 if (!strings::safe_strto64(StringPiece(string_value), result)) { 655 return errors::InvalidArgument("Couldn't interpret the ", name, 656 " argument as a number:", string_value); 657 } 658 return Status::OK(); 659 } 660 661 Status TransformFuncContext::GetOneFloatParameter(const string& name, 662 float default_value, 663 float* result) const { 664 const int params_count = CountParameters(name); 665 if (params_count == 0) { 666 *result = default_value; 667 return Status::OK(); 668 } 669 string string_value; 670 TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value)); 671 if (!strings::safe_strtof(string_value.c_str(), result)) { 672 return errors::InvalidArgument( 673 "Couldn't interpret the ", name, 674 " argument as a float number:", string_value); 675 } 676 return Status::OK(); 677 } 678 679 Status TransformFuncContext::GetOneBoolParameter(const string& name, 680 bool default_value, 681 bool* result) const { 682 const int params_count = CountParameters(name); 683 if (params_count == 0) { 684 *result = default_value; 685 return Status::OK(); 686 } 687 string string_value; 688 TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value)); 689 if (string_value == "true" || string_value == "1") { 690 *result = true; 691 } else if (string_value == "false" || string_value == "0") { 692 *result = false; 693 } else { 694 return errors::InvalidArgument("Couldn't interpret the ", name, 695 " argument as a boolean:", string_value, 696 " (expected true, false, 0 or 1)"); 697 } 698 return Status::OK(); 699 } 700 701 } // namespace graph_transforms 702 } // namespace tensorflow 703