Home | History | Annotate | Download | only in grappler
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_H_
     17 #define TENSORFLOW_CORE_GRAPPLER_UTILS_H_
     18 
     19 #include <functional>
     20 #include <iterator>
     21 #include <set>
     22 #include <unordered_set>
     23 #include <utility>
     24 #include <vector>
     25 #include "absl/types/span.h"
     26 #include "tensorflow/core/framework/graph.pb.h"
     27 #include "tensorflow/core/framework/node_def.pb.h"
     28 #include "tensorflow/core/framework/tensor.h"
     29 #include "tensorflow/core/framework/types.h"
     30 #include "tensorflow/core/graph/tensor_id.h"
     31 #include "tensorflow/core/lib/core/status.h"
     32 #include "tensorflow/core/lib/core/stringpiece.h"
     33 #include "tensorflow/core/lib/core/threadpool.h"
     34 #include "tensorflow/core/lib/gtl/flatmap.h"
     35 #include "tensorflow/core/lib/gtl/flatset.h"
     36 #include "tensorflow/core/lib/gtl/inlined_vector.h"
     37 #include "tensorflow/core/platform/types.h"
     38 
     39 namespace tensorflow {
     40 namespace grappler {
     41 
// A utility class to look up a node and its outputs (consumer nodes) by node
// name.
//
// NOTE(review): the map is populated from a GraphDef in the constructor; it
// is presumably NOT updated automatically when the graph changes, so callers
// that rewrite the graph are expected to use the Add*/Remove*/Update* methods
// below to keep it consistent — confirm against utils.cc.
class NodeMap {
 public:
  // Note: The NodeMap will store pointers to nodes in graph, which may become
  // invalid if graph is changed.
  explicit NodeMap(GraphDef* graph);
  // Returns the node registered under `name`.
  // NOTE(review): presumably nullptr for unknown names — confirm in utils.cc.
  NodeDef* GetNode(const string& name) const;
  // True iff a node named `name` is registered in the map.
  bool NodeExists(const string& name) const;
  // Returns the set of nodes recorded as outputs (consumers) of `node_name`.
  // The `empty_set_` member suggests a reference to an empty set is returned
  // for nodes with no recorded outputs — confirm in utils.cc.
  const std::set<NodeDef*>& GetOutputs(const string& node_name) const;
  // This method doesn't record the outputs of the added node; the outputs need
  // to be explicitly added by the AddOutput method.
  void AddNode(const string& name, NodeDef* node);
  // Unregisters the node named `name`.
  void RemoveNode(const string& name);
  // Records that `node_name` now reads `new_input_name` instead of
  // `old_input_name` (updates the output sets of the two producers).
  // NOTE(review): exact bookkeeping lives in utils.cc — confirm.
  void UpdateInput(const string& node_name, const string& old_input_name,
                   const string& new_input_name);
  // Records `output_name` as a consumer of `node_name`.
  void AddOutput(const string& node_name, const string& output_name);
  // Removes `node_name` from the output sets of the nodes it reads from
  // (presumably; confirm in utils.cc).
  void RemoveInputs(const string& node_name);
  // Removes `output_name` from the consumers recorded for `node_name`.
  void RemoveOutput(const string& node_name, const string& output_name);
  // Clears all consumers recorded for `node_name`.
  void RemoveOutputs(const string& node_name);
  // Replaces consumer `old_output_name` with `new_output_name` in the outputs
  // recorded for `node_name`.
  void UpdateOutput(const string& node_name, const string& old_output_name,
                    const string& new_output_name);

 private:
  // Returned by reference for lookups that find no recorded outputs
  // (presumably; see GetOutputs).
  const std::set<NodeDef*> empty_set_;
  // Node name -> node pointer into the GraphDef passed to the constructor.
  gtl::FlatMap<string, NodeDef*> nodes_;
  // Node name -> set of nodes consuming that node's outputs.
  gtl::FlatMap<string, std::set<NodeDef*>> outputs_;
};
     69 
     70 // A vector with a set. The set stores the same elements as the vector, and
     71 // quickly answers whether a value is in the vector. Duplicated elements are not
     72 // allowed for now.
     73 template <class T, class Hash = std::hash<T>>
     74 class SetVector {
     75  public:
     76   // Returns false if value already existed in the set, true otherwise.
     77   bool PushBack(const T& value) {
     78     if (!set_.insert(value).second) {
     79       return false;
     80     }
     81     vector_.push_back(value);
     82     return true;
     83   }
     84 
     85   T PopBack() {
     86     T back = vector_.back();
     87     set_.erase(back);
     88     vector_.pop_back();
     89     return back;
     90   }
     91 
     92   bool Exists(const T& value) const { return set_.find(value) != set_.end(); }
     93 
     94   bool Empty() const { return vector_.empty(); }
     95 
     96   void Reserve(int64 size) { vector_.reserve(size); }
     97 
     98  private:
     99   gtl::FlatSet<T, Hash> set_;
    100   std::vector<T> vector_;
    101 };
    102 
// Returns formatted string from TensorId specific to grappler. Specifically,
// for the 0 port (first output), only the node name is returned.
// NOTE(review): other ports are presumably rendered as "node:port" — confirm
// in utils.cc.
string TensorIdToString(const TensorId& tensor_id);

// True iff 'name' refers to a control input, i.e. a node name prefixed with
// the ^ character.
bool IsControlInput(const string& name);

// True iff 'tensor_id' refers to a control input.
bool IsControlInput(const TensorId& tensor_id);

// True iff 'name1' and 'name2' refer to the same input.
bool IsSameInput(const string& name1, const string& name2);
    116 
    117 // Returns the trailing position number (or zero if no number is present) if
    118 // NodeName(input_name) is equal to node_name. Returns -1 for control inputs.
    119 // Returns -2 if NodeName(input_name) is not equal to node_name.
    120 // Note: This function is used very heavily, and this hand-optimized
    121 // version is 3-4x faster than the version using Scanner, which it replaced.
    122 // This is worth the reduction in readability.
    123 inline int NodePositionIfSameNode(const string& input_name,
    124                                   const string& node_name) {
    125   if (input_name.empty()) return -2;
    126   const bool is_ctrl = input_name[0] == '^';
    127   auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin();
    128   auto node_it = node_name.begin();
    129   if (node_name.empty() ||
    130       std::distance(input_it, input_name.end()) < node_name.size()) {
    131     return -2;
    132   }
    133   while (node_it != node_name.end()) {
    134     if (*input_it++ != *node_it++) {
    135       return -2;
    136     }
    137   }
    138   if (input_it == input_name.end()) {
    139     return is_ctrl ? -1 : 0;
    140   } else if (*input_it++ == ':') {
    141     StringPiece remaining(&(*input_it),
    142                           std::distance(input_it, input_name.end()));
    143     int position;
    144     if (!strings::safe_strto32(remaining, &position)) {
    145       return -2;
    146     }
    147     return is_ctrl ? -1 : position;
    148   } else {
    149     return -2;
    150   }
    151 }
    152 
    153 // Return the node name corresponding to 'name' if name is valid, or the empty
    154 // string otherwise.
    155 inline StringPiece NodeNameAsStringPiece(const string& name) {
    156   static const string empty;
    157   if (name.empty()) return StringPiece(empty);
    158   const auto begin_it = name[0] == '^' ? name.begin() + 1 : name.begin();
    159   auto end_it = begin_it;
    160   while (end_it != name.end() && *end_it != ':') {
    161     ++end_it;
    162   }
    163   if (end_it != name.end() && *end_it != ':') {
    164     return StringPiece(empty);
    165   }
    166   return StringPiece(&(*begin_it), std::distance(begin_it, end_it));
    167 }
    168 
    169 // Return the node name corresponding to 'name' if name is valid, or the empty
    170 // string otherwise.
    171 inline string NodeName(const string& name) {
    172   return string(NodeNameAsStringPiece(name));
    173 }
    174 
    175 // Returns the node name and position in a single call.
    176 // DEPRECATED(ezhulenev): Use TensorId and ParseTensorName.
    177 inline StringPiece ParseNodeNameAsStringPiece(const string& name,
    178                                               int* position) {
    179   static const string empty;
    180   if (name.empty()) {
    181     *position = 0;
    182     return StringPiece(empty);
    183   }
    184   const bool is_ctrl = name[0] == '^';
    185   const auto begin_it = is_ctrl ? name.begin() + 1 : name.begin();
    186   *position = is_ctrl ? -1 : 0;
    187   auto end_it = begin_it;
    188   while (end_it != name.end() && *end_it != ':') {
    189     ++end_it;
    190   }
    191   const StringPiece node_name(&(*begin_it), std::distance(begin_it, end_it));
    192   if (end_it != name.end()) {
    193     if (*end_it != ':') {
    194       return StringPiece(empty);
    195     } else if (!is_ctrl) {
    196       ++end_it;
    197       StringPiece remaining(&(*end_it), std::distance(end_it, name.end()));
    198       if (!strings::safe_strto32(remaining, position)) {
    199         return StringPiece(empty);
    200       }
    201     }
    202   }
    203   return node_name;
    204 }
    205 
    206 // Returns the node name and position in a single call.
    207 // DEPRECATED(ezhulenev): Use SafeTensorId and ParseTensorName.
    208 inline string ParseNodeName(const string& name, int* position) {
    209   return string(ParseNodeNameAsStringPiece(name, position));
    210 }
    211 
    212 inline int NodePosition(const string& name) {
    213   int position;
    214   ParseNodeNameAsStringPiece(name, &position);
    215   return position;
    216 }
    217 
// Adds a prefix to a node name, joined by a custom delimiter.
string AddPrefixToNodeName(const string& name, const string& prefix,
                           const string& delimiter);

// Adds a prefix to a node name.
// NOTE(review): the delimiter used by this overload is chosen in utils.cc —
// confirm there before relying on a specific format.
string AddPrefixToNodeName(const string& name, const string& prefix);

// Executes 'fn' in 'thread_pool' and waits up to 'timeout_in_ms' milliseconds
// for it to complete. Returns false if 'fn' has not completed by then
// (presumably true otherwise).
//
// If returning false, the 'fn' may still continue to execute in the
// thread-pool. It is the responsibility of the caller to reset the thread-pool
// as appropriate.
bool ExecuteWithTimeout(std::function<void()> fn, int64 timeout_in_ms,
                        thread::ThreadPool* thread_pool);

// Returns the node name prefixed with conventional symbol '^'
// for control dependency, given a NodeDef.
string AsControlDependency(const NodeDef& node);

// Returns the node name prefixed with conventional symbol '^'
// for control dependency, given a node name.
string AsControlDependency(const string& node);
    241 
// Returns true if the node is assigned to run on CPU device.
bool NodeIsOnCpu(const NodeDef* node);

// Returns true if the node is assigned to run on GPU device.
bool NodeIsOnGpu(const NodeDef* node);

// Returns the number of outputs of a node according to its OpDef. Note that
// some of the outputs may be unconnected.
// NOTE(review): 'graph' is presumably used to resolve ops defined in the
// graph's function library — confirm in utils.cc.
int NumOutputs(const NodeDef& node, GraphDef* graph);

// Returns true iff the node has at least one control input.
bool HasControlInputs(const NodeDef& node);

// Number of connected non-control inputs.
int NumNonControlInputs(const NodeDef& node);

// Number of connected non-control outputs.
int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map);

// Number of connected non-control data outputs (Ops that consume output tensor
// data, not just its shape).
int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map);

// Removes redundant control inputs from node.
void DedupControlInputs(NodeDef* node);

// Returns an error if an attribute with the given key does not exist in node.
Status CheckAttrExists(const NodeDef& node, const string& key);

// Returns an error if attributes with the given keys do not exist in node.
Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys);

// Returns the data type in attribute `type_attr` of `node`. If that attribute
// doesn't exist, returns DT_INVALID.
DataType GetDataTypeFromAttr(const NodeDef& node, const string& type_attr);
    277 
// Returns the last node in the simple chain starting at source and traversing
// through the input(0) edge from each node as long as the next node satisfies
// the predicate given in pred_fn. If no nodes satisfy the predicate, &source
// will be returned. Example: For the chain
//    source <- a <- b <- ... <- y <- z
// where
//    pred_fn(a) = pred_fn(b) = ... = pred_fn(y) = true,
//    pred_fn(z) = false,
// the return value will be a pointer to y.
// NOTE(review): 'follow_control_input' presumably allows input(0) to be a
// control input ("^name") — confirm in utils.cc.
NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
                        bool follow_control_input,
                        const std::function<bool(const NodeDef&)>& pred_fn);

// Permute the nodes of graph in place according to the permutation.
// NOTE(review): 'invert_permutation' presumably selects whether *permutation
// is applied as given or as its inverse, and *permutation may be modified —
// confirm in utils.cc.
void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
                         bool invert_permutation);

// Returns Status::OK() if a kernel is registered for node.op() on the device
// type corresponding to node.device().
Status IsKernelRegisteredForNode(const NodeDef& node);

// NOTE(review): presumably stores 'value', converted to 'dtype', into
// '*tensor' and returns an error for unsupported dtypes — confirm in
// utils.cc.
Status SetTensorValue(DataType dtype, int value, Tensor* tensor);

// Removes nodes from 'graph'. The int overloads presumably take indices into
// graph->node(), and the string overload takes node names — confirm in
// utils.cc.
void EraseNodesFromGraph(const std::set<int>& nodes_to_delete, GraphDef* graph);

void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph);

void EraseNodesFromGraph(const std::set<string>& nodes_to_delete,
                         GraphDef* graph);
    307 
    308 }  // end namespace grappler
    309 }  // end namespace tensorflow
    310 
    311 #endif  // TENSORFLOW_CORE_GRAPPLER_UTILS_H_
    312