Home | History | Annotate | Download | only in graph_transforms
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/tools/graph_transforms/transform_utils.h"
     17 
     18 #include "tensorflow/core/framework/node_def_util.h"
     19 #include "tensorflow/core/framework/op.h"
     20 #include "tensorflow/core/lib/hash/hash.h"
     21 #include "tensorflow/core/lib/strings/str_util.h"
     22 
     23 namespace tensorflow {
     24 namespace graph_transforms {
     25 
     26 namespace {
     27 inline bool IsMerge(const NodeDef& node_def) {
     28   return node_def.op() == "Merge" || node_def.op() == "RefMerge";
     29 }
     30 
     31 void RecordMatchedNodes(const NodeMatch& match,
     32                         std::set<string>* matched_nodes) {
     33   matched_nodes->insert(match.node.name());
     34   for (const NodeMatch& input_match : match.inputs) {
     35     RecordMatchedNodes(input_match, matched_nodes);
     36   }
     37 }
     38 
     39 inline uint64 Hash64String(const string& input) {
     40   return Hash64(input.data(), input.size());
     41 }
     42 }  // namespace
     43 
     44 void MatchedNodesAsArray(const NodeMatch& match, std::vector<NodeDef>* result) {
     45   std::set<string> found_nodes;
     46   std::vector<NodeMatch> current_matches = {match};
     47   while (!current_matches.empty()) {
     48     std::vector<NodeMatch> next_matches;
     49     for (const NodeMatch& current_match : current_matches) {
     50       if (found_nodes.count(current_match.node.name())) {
     51         continue;
     52       }
     53       found_nodes.insert(current_match.node.name());
     54       result->push_back(current_match.node);
     55       for (const NodeMatch& input_match : current_match.inputs) {
     56         next_matches.push_back(input_match);
     57       }
     58     }
     59     current_matches = next_matches;
     60   }
     61 }
     62 
     63 void MapNamesToNodes(const GraphDef& graph_def,
     64                      std::map<string, const NodeDef*>* result) {
     65   for (const NodeDef& node : graph_def.node()) {
     66     (*result)[node.name()] = &node;
     67   }
     68 }
     69 
     70 void MapNodesToOutputs(const GraphDef& graph_def,
     71                        std::map<string, std::vector<const NodeDef*>>* result) {
     72   std::map<string, const NodeDef*> node_map;
     73   MapNamesToNodes(graph_def, &node_map);
     74   for (const NodeDef& node : graph_def.node()) {
     75     for (const string& input : node.input()) {
     76       string input_node_name = NodeNameFromInput(input);
     77       (*result)[input_node_name].push_back(&node);
     78     }
     79   }
     80 }
     81 
     82 void NodeNamePartsFromInput(const string& input_name, string* prefix,
     83                             string* node_name, string* suffix) {
     84   std::vector<string> input_parts = str_util::Split(input_name, ':');
     85   if (input_parts.size() < 2) {
     86     *suffix = "";
     87   } else {
     88     *suffix = ":" + input_parts[1];
     89   }
     90   StringPiece node_name_piece(input_parts[0]);
     91   if (node_name_piece.Consume("^")) {
     92     *prefix = "^";
     93   } else {
     94     *prefix = "";
     95   }
     96   *node_name = node_name_piece.ToString();
     97 }
     98 
     99 string NodeNameFromInput(const string& input_name) {
    100   string prefix;
    101   string node_name;
    102   string suffix;
    103   NodeNamePartsFromInput(input_name, &prefix, &node_name, &suffix);
    104   return node_name;
    105 }
    106 
    107 string CanonicalInputName(const string& input_name) {
    108   string prefix;
    109   string node_name;
    110   string suffix;
    111   NodeNamePartsFromInput(input_name, &prefix, &node_name, &suffix);
    112   if (suffix.empty()) {
    113     suffix = ":0";
    114   }
    115   return prefix + node_name + suffix;
    116 }
    117 
    118 uint64 HashNodeDef(const NodeDef& node) {
    119   uint64 hash = Hash64String(node.op());
    120   hash = Hash64Combine(hash, Hash64String(node.name()));
    121   for (const string& input : node.input()) {
    122     hash = Hash64Combine(hash, Hash64String(CanonicalInputName(input)));
    123   }
    124   hash = Hash64Combine(hash, Hash64String(node.device()));
    125   std::vector<string> attr_names;
    126   attr_names.reserve(node.attr().size());
    127   for (const auto& attr : node.attr()) {
    128     attr_names.push_back(attr.first);
    129   }
    130   std::sort(attr_names.begin(), attr_names.end());
    131   string attr_serialized;
    132   for (const string& attr_name : attr_names) {
    133     auto attr = node.attr().at(attr_name);
    134     attr.SerializeToString(&attr_serialized);
    135     hash = Hash64Combine(hash, Hash64String(attr_serialized));
    136   }
    137   return hash;
    138 }
    139 
    140 void AddNodeInput(const string& input_name, NodeDef* node) {
    141   *(node->mutable_input()->Add()) = input_name;
    142 }
    143 
    144 void CopyNodeAttr(const NodeDef& source, const string& source_key,
    145                   const string& dest_key, NodeDef* dest) {
    146   CHECK_NE(0, source.attr().count(source_key))
    147       << "No key '" << source_key << "' found in " << source.DebugString();
    148   (*(dest->mutable_attr()))[dest_key] = source.attr().at(source_key);
    149 }
    150 
    151 Tensor GetNodeTensorAttr(const NodeDef& node, const string& key) {
    152   TensorProto tensor_proto = node.attr().at(key).tensor();
    153   Tensor tensor;
    154   CHECK(tensor.FromProto(tensor_proto));
    155   return tensor;
    156 }
    157 
    158 void FilterGraphDef(const GraphDef& input_graph_def,
    159                     std::function<bool(const NodeDef&)> selector,
    160                     GraphDef* output_graph_def) {
    161   output_graph_def->mutable_node()->Clear();
    162   for (const NodeDef& node : input_graph_def.node()) {
    163     if (selector(node)) {
    164       *output_graph_def->mutable_node()->Add() = node;
    165     }
    166   }
    167 }
    168 
    169 void RemoveAttributes(const GraphDef& input_graph_def,
    170                       const std::vector<string>& attributes,
    171                       GraphDef* output_graph_def) {
    172   output_graph_def->mutable_node()->Clear();
    173   for (const NodeDef& node : input_graph_def.node()) {
    174     NodeDef* new_node = output_graph_def->mutable_node()->Add();
    175     *new_node = node;
    176     for (const string& attribute : attributes) {
    177       new_node->mutable_attr()->erase(attribute);
    178     }
    179   }
    180 }
    181 
// Copies the nodes of `input_graph_def` into `output_graph_def`, ordered so
// that a node is emitted only once the inputs it is waiting on have been
// emitted.
//
// Returns InvalidArgument if an input refers to a node name that is not in the
// graph, or if some nodes never become ready (reported as being in a cycle).
// `output_graph_def` is cleared only after all inputs have been parsed, so it
// is left untouched by the unknown-input error, and holds the nodes emitted so
// far when the cycle error is returned.
Status SortByExecutionOrder(const GraphDef& input_graph_def,
                            GraphDef* output_graph_def) {
  const int num_nodes = input_graph_def.node_size();
  // Indices of nodes whose pending count has reached zero; used as a stack.
  std::vector<int> ready;
  // For each node index, how many inputs still have to be emitted first.
  std::vector<int> pending_count;
  pending_count.reserve(num_nodes);
  // For each node index, the indices of the nodes that consume it.
  std::vector<gtl::InlinedVector<int, 4>> outputs(num_nodes);

  // Maps a node name to its index in input_graph_def.
  std::map<string, int> name_index;
  for (int i = 0; i < input_graph_def.node_size(); ++i) {
    const NodeDef& node(input_graph_def.node(i));
    name_index[node.name()] = i;
  }

  // Parse the inputs for each node.
  for (int n = 0; n < num_nodes; ++n) {
    const NodeDef& node_def(input_graph_def.node(n));
    if (IsMerge(node_def)) {
      // A merge node only waits for its control inputs plus a single
      // non-control input, rather than for every input.
      int32 num_control_edges = 0;
      for (int i = 0; i < node_def.input_size(); ++i) {
        StringPiece input_name(node_def.input(i));
        if (input_name.starts_with("^")) {
          num_control_edges++;
        }
      }
      pending_count.push_back(num_control_edges + 1);
    } else {
      pending_count.push_back(node_def.input_size());
    }
    if (node_def.input_size() == 0) {
      // Nothing to wait for, so this node can be emitted straight away.
      ready.push_back(n);
      continue;
    }
    for (int i = 0; i < node_def.input_size(); ++i) {
      const string& input_name = node_def.input(i);
      const string& input_node_name = NodeNameFromInput(input_name);
      if (!name_index.count(input_node_name)) {
        return errors::InvalidArgument("Node '", node_def.name(),
                                       "': Unknown input node '",
                                       node_def.input(i), "'");
      }
      outputs[name_index[input_node_name]].push_back(n);
    }
  }

  int processed = 0;
  output_graph_def->Clear();
  // Process the NodeDefs in topological order.
  // The code above sets this up by filling in `ready` with the nodes that have
  // no inputs, `pending_count` with the number of inputs each node waits for,
  // and `outputs` with the consumers of each node.
  while (!ready.empty()) {
    int o = ready.back();
    ready.pop_back();
    ++processed;
    const NodeDef& node_def(input_graph_def.node(o));
    *output_graph_def->mutable_node()->Add() = node_def;

    // Update pending_count for outputs.
    for (size_t i = 0; i < outputs[o].size(); ++i) {
      const int output = outputs[o][i];
      pending_count[output]--;
      // Only the transition to exactly zero queues a node, so a merge node
      // whose count later drops below zero is not queued a second time.
      if (pending_count[output] == 0) {
        ready.push_back(output);
      }
    }
  }

  // Any node that never reached a pending count of zero was never emitted.
  if (processed < input_graph_def.node_size()) {
    return errors::InvalidArgument(input_graph_def.node_size() - processed,
                                   " nodes in a cycle");
  }
  return Status::OK();
}
    257 
    258 string OpTypePattern::DebugString() const {
    259   string result = "{" + op + ", {";
    260   for (const OpTypePattern& input : inputs) {
    261     result += input.DebugString() + ",";
    262   }
    263   result += "}}";
    264   return result;
    265 }
    266 
    267 string NodeMatch::DebugString() const {
    268   string result = "{";
    269   result += node.DebugString();
    270   result += ", {";
    271   for (const NodeMatch& input : inputs) {
    272     result += input.DebugString() + ",";
    273   }
    274   result += "}}";
    275   return result;
    276 }
    277 
    278 GraphMatcher::GraphMatcher(const GraphDef& graph_def) {
    279   SortByExecutionOrder(graph_def, &graph_def_).IgnoreError();
    280   MapNamesToNodes(graph_def_, &node_map_);
    281 }
    282 
    283 Status GraphMatcher::GetOpTypeMatches(const OpTypePattern& pattern,
    284                                       std::vector<NodeMatch>* matches) {
    285   std::set<string> matched_nodes;
    286   for (const NodeDef& node : graph_def_.node()) {
    287     // Skip any nodes that are already part of a match.
    288     if (matched_nodes.count(node.name())) {
    289       continue;
    290     }
    291     NodeMatch match;
    292     if (DoesOpTypeMatch(node, pattern, matched_nodes, &match)) {
    293       RecordMatchedNodes(match, &matched_nodes);
    294       matches->push_back(match);
    295     }
    296   }
    297   return Status::OK();
    298 }
    299 
    300 bool GraphMatcher::DoesOpTypeMatch(
    301     const NodeDef& node, const OpTypePattern& pattern,
    302     const std::set<string>& previously_matched_nodes, NodeMatch* match) {
    303   VLOG(1) << "Looking at node " << node.DebugString();
    304   VLOG(1) << "pattern=" << pattern.DebugString();
    305   VLOG(1) << "match=" << match->DebugString();
    306   if (previously_matched_nodes.count(node.name())) {
    307     VLOG(1) << "node " << node.name() << " has been previously matched";
    308     return false;
    309   }
    310   bool pattern_matched = false;
    311   if (pattern.op == "*") {
    312     pattern_matched = true;
    313   } else {
    314     std::vector<string> pattern_ops = str_util::Split(pattern.op, '|');
    315     for (const string& pattern_op : pattern_ops) {
    316       if (node.op() == pattern_op) {
    317         pattern_matched = true;
    318       }
    319     }
    320   }
    321   if (!pattern_matched) {
    322     VLOG(1) << "node.op() != pattern.op()";
    323     return false;
    324   }
    325   match->node = node;
    326   // Ignore any control inputs for pattern-matching purposes
    327   std::vector<string> non_control_inputs;
    328   for (const string& input : node.input()) {
    329     if (!input.empty() && (input[0] != '^')) {
    330       non_control_inputs.push_back(input);
    331     }
    332   }
    333   if (pattern.inputs.empty()) {
    334     // If there are no inputs, assume that's the end of the pattern.
    335     return true;
    336   }
    337   if (non_control_inputs.size() != pattern.inputs.size()) {
    338     VLOG(1) << "non_control_inputs.size() != pattern.inputs.size()";
    339     return false;
    340   }
    341   for (int i = 0; i < pattern.inputs.size(); ++i) {
    342     const string& input_node_name = NodeNameFromInput(non_control_inputs[i]);
    343     const NodeDef& input_node = *(node_map_[input_node_name]);
    344     const OpTypePattern& input_pattern = pattern.inputs[i];
    345     match->inputs.push_back(NodeMatch());
    346     NodeMatch* input_match = &(match->inputs.back());
    347     if (!DoesOpTypeMatch(input_node, input_pattern, previously_matched_nodes,
    348                          input_match)) {
    349       return false;
    350     }
    351   }
    352   return true;
    353 }
    354 
// Rewrites `input_graph_def` into `output_graph_def` by finding every
// subgraph that fits `pattern` and substituting the nodes produced by
// `node_generator` for it.
//
//   node_generator - called once per match with the match, the set of matched
//       node names that take at least one input from outside the match, the
//       set of matched node names that are consumed by nodes outside the
//       match, and a vector to fill with the replacement nodes. An error
//       returned by it is propagated to the caller.
//   options - when `allow_inconsistencies` is false, a replacement that drops
//       any externally consumed node is discarded and the original matched
//       nodes are copied over instead.
//
// Nodes that belong to no match are copied unchanged. Replacement nodes are
// emitted at the position of the match's head node in the input graph.
Status ReplaceMatchingOpTypes(
    const GraphDef& input_graph_def, const OpTypePattern& pattern,
    const std::function<Status(const NodeMatch&, const std::set<string>&,
                               const std::set<string>&, std::vector<NodeDef>*)>&
        node_generator,
    const ReplaceMatchingOpTypesOptions& options, GraphDef* output_graph_def) {
  // Start off by retrieving all the matching subgraphs.
  GraphMatcher matcher(input_graph_def);
  std::vector<NodeMatch> matches;
  TF_RETURN_IF_ERROR(matcher.GetOpTypeMatches(pattern, &matches));

  // Do some housekeeping so we can easily look up the resulting matches given
  // a node name.
  std::set<string> matched_nodes;
  // The stored pointers refer to elements of `matches`, which is not modified
  // after this loop.
  std::map<string, const NodeMatch*> matches_by_head_name;
  for (const NodeMatch& match : matches) {
    matches_by_head_name[match.node.name()] = &match;
    RecordMatchedNodes(match, &matched_nodes);
  }
  std::map<string, std::vector<const NodeDef*>> outputs_map;
  MapNodesToOutputs(input_graph_def, &outputs_map);

  // Go through all the nodes in the input graph, see if they are part of a
  // match or if they can be left untouched.
  output_graph_def->Clear();
  for (const NodeDef& input_node : input_graph_def.node()) {
    if (matches_by_head_name.count(input_node.name())) {
      // This node is the beginning of a match, so call the replacement function
      // after setting up some information it will need.
      const NodeMatch* match = matches_by_head_name[input_node.name()];
      std::vector<NodeDef> matched_nodes_array;
      MatchedNodesAsArray(*match, &matched_nodes_array);
      // This tells us whether a node is part of the current match.
      std::set<string> matched_nodes_lookup;
      for (const NodeDef& matched_node : matched_nodes_array) {
        matched_nodes_lookup.insert(matched_node.name());
      }
      // These are helper arrays that the replacement function can use to tell
      // whether it can safely remove an internal node (because nothing outside
      // of the match uses it) or whether external nodes depend on it.
      std::set<string> input_nodes;
      std::set<string> output_nodes;
      for (const NodeDef& matched_node : matched_nodes_array) {
        // Look through all of this node's inputs, and if any of them come from
        // outside the match, then this should be noted as one of the external
        // inputs of the subgraph.
        for (const string& input_name : matched_node.input()) {
          string input_node_name = NodeNameFromInput(input_name);
          if (!matched_nodes_lookup.count(input_node_name)) {
            input_nodes.insert(matched_node.name());
          }
        }
        // Do a reverse input lookup, to see which other nodes use the current
        // one as an input. If any of those nodes are outside the match
        // subgraph, then the current node is marked as an output node that
        // shouldn't be removed.
        if (outputs_map.count(matched_node.name())) {
          for (const NodeDef* dependent_node :
               outputs_map[matched_node.name()]) {
            if (!matched_nodes_lookup.count(dependent_node->name())) {
              output_nodes.insert(matched_node.name());
            }
          }
        }
      }
      // Call the generator function and add all the returned nodes to the
      // graph.
      std::vector<NodeDef> new_nodes;
      TF_RETURN_IF_ERROR(
          node_generator(*match, input_nodes, output_nodes, &new_nodes));
      std::set<string> new_node_names;
      for (const NodeDef& new_node : new_nodes) {
        new_node_names.insert(new_node.name());
      }
      // Check to make sure the generator function preserved all of the nodes
      // that are used elsewhere in the graph, and add them back in if not.
      bool abort_replacement = false;
      if (!options.allow_inconsistencies) {
        for (const string& expected_output : output_nodes) {
          if (!new_node_names.count(expected_output)) {
            LOG(WARNING) << "Expected " << expected_output
                         << " to be preserved.";
            abort_replacement = true;
          }
        }
      }
      if (abort_replacement) {
        // Discard the generator's output and emit the original matched nodes.
        LOG(WARNING) << "Generator function didn't preserve needed nodes, "
                     << "copying old replacements back in instead.";
        std::vector<NodeDef> old_nodes;
        MatchedNodesAsArray(*match, &old_nodes);
        for (const NodeDef& old_node : old_nodes) {
          NodeDef* added_node = output_graph_def->mutable_node()->Add();
          *added_node = old_node;
        }
      } else {
        for (const NodeDef& new_node : new_nodes) {
          NodeDef* added_node = output_graph_def->mutable_node()->Add();
          *added_node = new_node;
        }
      }
    } else if (!matched_nodes.count(input_node.name())) {
      // This node isn't part of any match, so just copy it over.
      NodeDef* added_node = output_graph_def->mutable_node()->Add();
      *added_node = input_node;
    } else {
      // Do nothing, because this is an internal part of a matching subgraph,
      // and so will have been replaced by a new replacement subgraph.
    }
  }

  return Status::OK();
}
    468 
// Copies `input_graph_def` into `output_graph_def`, rewriting node inputs
// according to `inputs_to_rename` (source input name -> destination name).
//
// A source name ending in ":*" applies to every input that refers to that
// node; the prefix and suffix of the input being rewritten are kept and only
// the node name is replaced. Any other source name applies only to inputs
// whose canonical form equals the canonical form of the source, and the input
// is replaced by the destination name as given.
//
// Renames are applied repeatedly, so chains (a -> b, b -> c) are followed to
// the end. Returns InvalidArgument if following a chain comes back to a node
// name that was already visited for the same input. Inputs of nodes whose
// name is in `nodes_to_ignore` are copied unchanged.
Status RenameNodeInputs(const GraphDef& input_graph_def,
                        const std::map<string, string>& inputs_to_rename,
                        const std::unordered_set<string>& nodes_to_ignore,
                        GraphDef* output_graph_def) {
  // Group the rename rules by the node name of their source, so that all the
  // rules that could apply to an input are found with a single lookup.
  std::map<string, std::vector<std::pair<string, string>>>
      canonical_inputs_to_rename;
  for (const auto& input_to_rename : inputs_to_rename) {
    canonical_inputs_to_rename[NodeNameFromInput(input_to_rename.first)]
        .push_back({input_to_rename.first, input_to_rename.second});
  }

  output_graph_def->Clear();
  for (const NodeDef& node : input_graph_def.node()) {
    NodeDef* new_node = output_graph_def->mutable_node()->Add();
    *new_node = node;
    // The inputs are rebuilt one by one below.
    new_node->mutable_input()->Clear();
    for (const string& input_name : node.input()) {
      // Node names already seen while following renames for this input.
      std::set<string> already_visited;
      string new_input_name = input_name;
      // Keep going while the current name refers to a node that has rules.
      while (
          canonical_inputs_to_rename.count(NodeNameFromInput(new_input_name))) {
        string input_node_name = NodeNameFromInput(new_input_name);
        if (already_visited.count(input_node_name)) {
          return errors::InvalidArgument(
              "RenameNodeInputs argument contains a cycle for ",
              input_node_name);
        }
        already_visited.insert(input_node_name);
        // Ignored nodes keep their inputs exactly as they were.
        if (nodes_to_ignore.count(node.name())) {
          break;
        }
        bool any_match_found = false;
        for (const std::pair<string, string>& input_to_rename :
             canonical_inputs_to_rename.at(input_node_name)) {
          const string& source_name = input_to_rename.first;
          const string& dest_name = input_to_rename.second;
          bool is_match;
          string match_name;
          if (StringPiece(source_name).ends_with(":*")) {
            // Wildcard rule: swap the node name and keep the input's own
            // prefix and suffix.
            is_match = true;
            string prefix;
            string unused_node_name;
            string suffix;
            NodeNamePartsFromInput(new_input_name, &prefix, &unused_node_name,
                                   &suffix);
            match_name = prefix + dest_name + suffix;
          } else {
            // Exact rule: compare canonical forms so that "name" and "name:0"
            // are treated as the same input.
            is_match = (CanonicalInputName(source_name) ==
                        CanonicalInputName(new_input_name));
            match_name = dest_name;
          }
          if (is_match) {
            new_input_name = match_name;
            any_match_found = true;
          }
        }
        // No rule applied, so the name is final.
        if (!any_match_found) {
          break;
        }
      }
      *(new_node->mutable_input()->Add()) = new_input_name;
    }
  }
  return Status::OK();
}
    534 
    535 void CopyOriginalMatch(const NodeMatch& match,
    536                        std::vector<NodeDef>* new_nodes) {
    537   std::vector<NodeDef> old_nodes;
    538   MatchedNodesAsArray(match, &old_nodes);
    539   for (const NodeDef& old_node : old_nodes) {
    540     new_nodes->push_back(old_node);
    541   }
    542 }
    543 
    544 TransformRegistry* GetTransformRegistry() {
    545   static TransformRegistry transform_registry;
    546   return &transform_registry;
    547 }
    548 
    549 void FindInvalidInputs(const GraphDef& graph_def,
    550                        std::vector<std::pair<string, string>>* invalid_inputs) {
    551   std::map<string, const NodeDef*> node_map;
    552   MapNamesToNodes(graph_def, &node_map);
    553 
    554   for (const NodeDef& node : graph_def.node()) {
    555     for (const string& input : node.input()) {
    556       string input_node = NodeNameFromInput(input);
    557       if (!node_map.count(input_node)) {
    558         invalid_inputs->push_back({node.name(), input_node});
    559       }
    560     }
    561   }
    562 }
    563 
    564 Status IsGraphValid(const GraphDef& graph_def) {
    565   std::vector<std::pair<string, string>> invalid_inputs;
    566   FindInvalidInputs(graph_def, &invalid_inputs);
    567   if (!invalid_inputs.empty()) {
    568     std::map<string, const NodeDef*> node_map;
    569     MapNamesToNodes(graph_def, &node_map);
    570     for (const std::pair<string, string>& invalid_input : invalid_inputs) {
    571       LOG(ERROR) << "Invalid input " << invalid_input.second << " for node "
    572                  << invalid_input.first << " - "
    573                  << node_map[invalid_input.first]->DebugString();
    574     }
    575     return errors::Internal(
    576         "Invalid graph with inputs referring to nonexistent nodes");
    577   }
    578   return Status::OK();
    579 }
    580 
    581 Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
    582                      DataTypeVector* outputs) {
    583   const OpDef* op_def;
    584   TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def));
    585   TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, *op_def, inputs, outputs));
    586   return Status::OK();
    587 }
    588 
    589 Status TensorShapeFromString(const string& shape_string, TensorShape* result) {
    590   if (shape_string.empty()) {
    591     return errors::InvalidArgument("Specificed shape is empty.");
    592   }
    593   std::vector<int64> dims;
    594   if (!str_util::SplitAndParseAsInts(shape_string, ',', &dims)) {
    595     return errors::InvalidArgument("Could parse as shape: '", shape_string,
    596                                    "'");
    597   }
    598   *result = TensorShape(dims);
    599   return Status::OK();
    600 }
    601 
    602 int TransformFuncContext::CountParameters(const string& name) const {
    603   if (params.count(name)) {
    604     return params.at(name).size();
    605   } else {
    606     return 0;
    607   }
    608 }
    609 
    610 Status TransformFuncContext::GetOneStringParameter(const string& name,
    611                                                    const string& default_value,
    612                                                    string* result) const {
    613   const int params_count = CountParameters(name);
    614   if (params_count == 0) {
    615     *result = default_value;
    616     return Status::OK();
    617   } else if (params_count == 1) {
    618     *result = params.at(name).at(0);
    619     return Status::OK();
    620   } else {
    621     return errors::InvalidArgument("Expected a single '", name,
    622                                    "' parameter, but found ", params_count,
    623                                    " occurrences");
    624   }
    625 }
    626 
    627 Status TransformFuncContext::GetOneInt32Parameter(const string& name,
    628                                                   int32 default_value,
    629                                                   int32* result) const {
    630   const int params_count = CountParameters(name);
    631   if (params_count == 0) {
    632     *result = default_value;
    633     return Status::OK();
    634   }
    635   string string_value;
    636   TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
    637   if (!strings::safe_strto32(StringPiece(string_value), result)) {
    638     return errors::InvalidArgument("Couldn't interpret the ", name,
    639                                    " argument as a number:", string_value);
    640   }
    641   return Status::OK();
    642 }
    643 
    644 Status TransformFuncContext::GetOneInt64Parameter(const string& name,
    645                                                   int64 default_value,
    646                                                   int64* result) const {
    647   const int params_count = CountParameters(name);
    648   if (params_count == 0) {
    649     *result = default_value;
    650     return Status::OK();
    651   }
    652   string string_value;
    653   TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
    654   if (!strings::safe_strto64(StringPiece(string_value), result)) {
    655     return errors::InvalidArgument("Couldn't interpret the ", name,
    656                                    " argument as a number:", string_value);
    657   }
    658   return Status::OK();
    659 }
    660 
    661 Status TransformFuncContext::GetOneFloatParameter(const string& name,
    662                                                   float default_value,
    663                                                   float* result) const {
    664   const int params_count = CountParameters(name);
    665   if (params_count == 0) {
    666     *result = default_value;
    667     return Status::OK();
    668   }
    669   string string_value;
    670   TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
    671   if (!strings::safe_strtof(string_value.c_str(), result)) {
    672     return errors::InvalidArgument(
    673         "Couldn't interpret the ", name,
    674         " argument as a float number:", string_value);
    675   }
    676   return Status::OK();
    677 }
    678 
    679 Status TransformFuncContext::GetOneBoolParameter(const string& name,
    680                                                  bool default_value,
    681                                                  bool* result) const {
    682   const int params_count = CountParameters(name);
    683   if (params_count == 0) {
    684     *result = default_value;
    685     return Status::OK();
    686   }
    687   string string_value;
    688   TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
    689   if (string_value == "true" || string_value == "1") {
    690     *result = true;
    691   } else if (string_value == "false" || string_value == "0") {
    692     *result = false;
    693   } else {
    694     return errors::InvalidArgument("Couldn't interpret the ", name,
    695                                    " argument as a boolean:", string_value,
    696                                    " (expected true, false, 0 or 1)");
    697   }
    698   return Status::OK();
    699 }
    700 
    701 }  // namespace graph_transforms
    702 }  // namespace tensorflow
    703