// tensorflow/core/grappler/utils.cc — graph-manipulation utilities for Grappler.
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/grappler/utils.h"
     17 
     18 #include <iterator>
     19 #include <memory>
     20 #include <queue>
     21 #include <vector>
     22 
     23 #include "absl/strings/match.h"
     24 #include "absl/strings/str_cat.h"
     25 #include "tensorflow/core/framework/attr_value.pb.h"
     26 #include "tensorflow/core/framework/function.h"
     27 #include "tensorflow/core/framework/node_def_util.h"
     28 #include "tensorflow/core/framework/op.h"
     29 #include "tensorflow/core/framework/op_def.pb.h"
     30 #include "tensorflow/core/framework/op_kernel.h"
     31 #include "tensorflow/core/framework/types.h"
     32 #include "tensorflow/core/lib/core/stringpiece.h"
     33 #include "tensorflow/core/lib/strings/numbers.h"
     34 #include "tensorflow/core/lib/strings/scanner.h"
     35 #include "tensorflow/core/lib/strings/strcat.h"
     36 #include "tensorflow/core/platform/notification.h"
     37 #include "tensorflow/core/util/device_name_utils.h"
     38 
     39 namespace tensorflow {
     40 namespace grappler {
     41 namespace {
     42 template <typename T>
     43 bool SafeSetDoubleScalarTensorValue(double value, Tensor* tensor) {
     44   using RealType = typename Eigen::NumTraits<T>::Real;
     45   if (value > static_cast<double>(Eigen::NumTraits<RealType>::highest()) ||
     46       value < static_cast<double>(Eigen::NumTraits<RealType>::lowest())) {
     47     return false;
     48   }
     49   tensor->flat<T>()(0) = static_cast<T>(value);
     50   return true;
     51 }
     52 
     53 template <typename T>
     54 bool SafeSetIntScalarTensorValue(int value, Tensor* tensor) {
     55   using RealType = typename Eigen::NumTraits<T>::Real;
     56   if (value > static_cast<int>(Eigen::NumTraits<RealType>::highest()) ||
     57       value < static_cast<int>(Eigen::NumTraits<RealType>::lowest())) {
     58     return false;
     59   }
     60   tensor->flat<T>()(0) = static_cast<T>(value);
     61   return true;
     62 }
     63 
     64 // Is 'node' an operator that consumes only the shape of its input, not the
     65 // data itself?
     66 // TODO(ezhulenev): move to op_types.h. Requires to break circular dependency.
     67 // TODO(ezhulenev): what about Identity passing tensor to Shape consumer?
     68 bool IsShapeConsumer(const NodeDef& node) {
     69   const string& op = node.op();
     70   return op == "Shape" || op == "ShapeN" || op == "Rank" || op == "Size";
     71 }
     72 
     73 }  // namespace
     74 
// Builds the name -> node index and the node -> fanout index for `graph`.
// The NodeMap stores raw pointers into `graph`, so the graph must outlive it.
NodeMap::NodeMap(GraphDef* graph) {
  CHECK(graph != nullptr);
  for (int i = 0; i < graph->node_size(); i++) {
    NodeDef* node = graph->mutable_node(i);
    const string& node_name = node->name();
    auto rslt = nodes_.emplace(node_name, node);
    // Check that the graph doesn't contain multiple nodes with the same name.
    if (!rslt.second) {
      LOG(WARNING) << "Duplicated node in the graph: " << node_name;
    }
    // Record this node in the fanout set of every node it reads from
    // (NodeName maps an input string to the producing node's name).
    for (const auto& input : node->input()) {
      outputs_[NodeName(input)].insert(nodes_[node_name]);
    }
  }
}
     90 
     91 void NodeMap::RemoveNode(const string& name) {
     92   nodes_.erase(NodeName(name));
     93   outputs_.erase(NodeName(name));
     94 }
     95 
     96 NodeDef* NodeMap::GetNode(const string& name) const {
     97   const string node_name = NodeName(name);
     98   auto it = nodes_.find(node_name);
     99   if (it == nodes_.end()) {
    100     return nullptr;
    101   }
    102   return it->second;
    103 }
    104 
    105 bool NodeMap::NodeExists(const string& name) const {
    106   const string node_name = NodeName(name);
    107   return nodes_.find(node_name) != nodes_.end();
    108 }
    109 
    110 const std::set<NodeDef*>& NodeMap::GetOutputs(const string& node_name) const {
    111   auto it = outputs_.find(node_name);
    112   if (it == outputs_.end()) {
    113     return empty_set_;
    114   }
    115   return it->second;
    116 }
    117 
    118 void NodeMap::AddNode(const string& node_name, NodeDef* node) {
    119   auto ret = nodes_.emplace(node_name, CHECK_NOTNULL(node));
    120   CHECK(ret.second) << "Pair (" << node_name << "," << node
    121                     << ") is not inserted because the same key already exists.";
    122 }
    123 
    124 void NodeMap::AddOutput(const string& node_name, const string& output_name) {
    125   auto output_node = nodes_[NodeName(output_name)];
    126   CHECK(output_node) << "Output node " << output_name
    127                      << " is missing in NodeMap.";
    128   outputs_[node_name].insert(output_node);
    129 }
    130 
    131 void NodeMap::RemoveOutput(const string& node_name, const string& output_name) {
    132   outputs_[node_name].erase(nodes_[NodeName(output_name)]);
    133 }
    134 
    135 void NodeMap::UpdateInput(const string& node_name, const string& old_input_name,
    136                           const string& new_input_name) {
    137   RemoveOutput(NodeName(old_input_name), node_name);
    138   AddOutput(NodeName(new_input_name), node_name);
    139 }
    140 
    141 void NodeMap::RemoveInputs(const string& node_name) {
    142   auto node = nodes_[node_name];
    143   for (const auto& input : node->input()) {
    144     RemoveOutput(NodeName(input), node->name());
    145   }
    146 }
    147 
// Forgets every recorded fanout of `node_name`. Does not touch nodes_.
void NodeMap::RemoveOutputs(const string& node_name) {
  outputs_.erase(node_name);
}
    151 
    152 void NodeMap::UpdateOutput(const string& node_name,
    153                            const string& old_output_name,
    154                            const string& new_output_name) {
    155   std::set<NodeDef*>& outputs = outputs_[node_name];
    156   outputs.erase(nodes_[NodeName(old_output_name)]);
    157   outputs.insert(nodes_[NodeName(new_output_name)]);
    158 }
    159 
    160 string TensorIdToString(const TensorId& tensor_id) {
    161   return tensor_id.index() == 0 ? string(tensor_id.node())
    162                                 : tensor_id.ToString();
    163 }
    164 
    165 bool IsSameInput(const string& name1, const string& name2) {
    166   if (name1 == name2) return true;
    167   TensorId tensor1 = ParseTensorName(name1);
    168   TensorId tensor2 = ParseTensorName(name2);
    169   return tensor1 == tensor2;
    170 }
    171 
    172 bool IsControlInput(const string& name) {
    173   return !name.empty() && name[0] == '^';
    174 }
    175 
    176 bool IsControlInput(const TensorId& tensor_id) { return tensor_id.index() < 0; }
    177 
    178 string AddPrefixToNodeName(const string& name, const string& prefix,
    179                            const string& delimiter) {
    180   if (!name.empty()) {
    181     if (name[0] == '^') {
    182       return absl::StrCat("^", prefix, delimiter, name.substr(1));
    183     }
    184   }
    185   return absl::StrCat(prefix, delimiter, name);
    186 }
    187 
// Convenience overload using "/" as the delimiter.
string AddPrefixToNodeName(const string& name, const string& prefix) {
  return AddPrefixToNodeName(name, prefix, "/");
}
    191 
    192 bool ExecuteWithTimeout(std::function<void()> fn, const int64 timeout_in_ms,
    193                         thread::ThreadPool* const thread_pool) {
    194   if (timeout_in_ms <= 0) {
    195     fn();
    196     return true;
    197   }
    198   auto done = std::make_shared<Notification>();
    199   thread_pool->Schedule([done, fn]() {
    200     fn();
    201     done->Notify();
    202   });
    203   const bool notified =
    204       WaitForNotificationWithTimeout(done.get(), timeout_in_ms * 1000);
    205   return notified;
    206 }
    207 
// Returns the control-dependency form ("^name") of `node`'s name.
string AsControlDependency(const NodeDef& node) {
  return absl::StrCat("^", node.name());
}
    211 
    212 string AsControlDependency(const string& node_name) {
    213   CHECK(!node_name.empty());
    214   return (!node_name.empty() && node_name[0] == '^')
    215              ? node_name
    216              : absl::StrCat("^", node_name);
    217 }
    218 
    219 bool NodeIsOnCpu(const NodeDef* node) {
    220   string task, device;
    221   return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
    222          absl::StartsWith(device, DEVICE_CPU);
    223 }
    224 
    225 bool NodeIsOnGpu(const NodeDef* node) {
    226   string task, device;
    227   return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
    228          absl::StartsWith(device, DEVICE_GPU);
    229 }
    230 
    231 int NumOutputs(const NodeDef& node, GraphDef* graph) {
    232   int num_outputs = 0;
    233   const OpDef* op_def = nullptr;
    234   auto status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
    235   if (status.ok()) {
    236     for (const auto& output : op_def->output_arg()) {
    237       if (!output.type_list_attr().empty()) {
    238         num_outputs +=
    239             node.attr().at(output.type_list_attr()).list().type_size();
    240       } else if (!output.number_attr().empty()) {
    241         num_outputs += node.attr().at(output.number_attr()).i();
    242       } else {
    243         num_outputs++;
    244       }
    245     }
    246   } else {
    247     FunctionLibraryDefinition fdef(OpRegistry::Global(), graph->library());
    248     auto status = fdef.LookUpOpDef(node.op(), &op_def);
    249     if (status.ok()) {
    250       num_outputs = op_def->output_arg_size();
    251     }
    252   }
    253   return num_outputs;
    254 }
    255 
    256 bool HasControlInputs(const NodeDef& node) {
    257   int num_inputs = node.input_size();
    258   if (num_inputs > 0 && IsControlInput(node.input(num_inputs - 1))) {
    259     return true;
    260   }
    261   return false;
    262 }
    263 
    264 int NumNonControlInputs(const NodeDef& node) {
    265   int num_inputs = node.input_size();
    266   for (const string& input : node.input()) {
    267     if (IsControlInput(input)) {
    268       --num_inputs;
    269     }
    270   }
    271   return num_inputs;
    272 }
    273 
// Counts the data (non-control) edges from `node` to its fanouts. A fanout
// that consumes several outputs of `node` contributes once per consumed edge.
int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) {
  int num_outputs = 0;
  for (const NodeDef* output : node_map.GetOutputs(node.name())) {
    for (const string& node_as_input : output->input()) {
      // Stop at the first control input: control inputs are assumed to come
      // after all data inputs in the fanout's input list — TODO confirm.
      if (IsControlInput(node_as_input)) {
        break;
      }
      if (node_as_input == node.name()) {
        // Bare name: implicit output port 0 of `node`.
        ++num_outputs;
      } else {
        // "name:port" form: compare only the node component.
        const TensorId tensor = ParseTensorName(node_as_input);
        if (tensor.node() == node.name()) {
          ++num_outputs;
        }
      }
    }
  }
  return num_outputs;
}
    293 
    294 int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map) {
    295   int num_data_outputs = 0;
    296   for (const NodeDef* output : node_map.GetOutputs(node.name())) {
    297     if (IsShapeConsumer(*output)) continue;
    298 
    299     for (int i = 0; i < output->input_size(); ++i) {
    300       const string& input = output->input(i);
    301       if (!IsControlInput(input) && NodeName(input) == node.name()) {
    302         ++num_data_outputs;
    303         break;
    304       }
    305     }
    306   }
    307   return num_data_outputs;
    308 }
    309 
    310 // Returns the data type in attribute `attr_name` of `node`. If that attribute
    311 // doesn't exist, returns DT_INVALID.
    312 DataType GetDataTypeFromAttr(const NodeDef& node, const string& type_attr) {
    313   if (!node.attr().count(type_attr)) {
    314     return DT_INVALID;
    315   }
    316   const auto& attr = node.attr().at(type_attr);
    317   if (attr.value_case() != AttrValue::kType) {
    318     return DT_INVALID;
    319   }
    320   return attr.type();
    321 }
    322 
// Walks upward through first (data) inputs starting at `source`, as long as
// `pred_fn` holds for each visited ancestor, and returns the last node
// accepted. The walk stops at a node with no inputs, at a control first-input
// (unless `follow_control_input`), when the next node fails `pred_fn`, or
// when the next node is missing from `node_map`.
NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
                        bool follow_control_input,
                        const std::function<bool(const NodeDef&)>& pred_fn) {
  const NodeDef* current = &source;
  const NodeDef* next = current;
  // `source` itself is accepted unconditionally; pred_fn only gates
  // its ancestors.
  while (next == &source || (next != nullptr && pred_fn(*next))) {
    current = next;
    if (current->input_size() == 0 ||
        (!follow_control_input && IsControlInput(current->input(0)))) {
      break;
    }
    next = node_map.GetNode(current->input(0));
    if (next == nullptr) {
      // Dangling input reference: log and return the last valid node.
      LOG(ERROR) << "Node not found: " << current->input(0);
    }
  }
  // Cast away constness: the map hands out mutable nodes elsewhere too.
  return const_cast<NodeDef*>(current);
}
    341 
// Every permutation is a product of one or more cycles. Iterate over the cycles
// in the permutation, and convert each of those into a product of
// transpositions (swaps): https://en.wikipedia.org/wiki/Cyclic_permutation
//
// Reorders graph->node() so that node i moves to position permutation[i]
// (or its inverse when invert_permutation is set). `permutation` is consumed:
// its contents are scrambled by the swap loop.
void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
                         bool invert_permutation) {
  CHECK_EQ(graph->node_size(), permutation->size());
  // Optionally replace the permutation by its inverse, in place.
  std::vector<int> inv_perm(permutation->size(), 0);
  if (invert_permutation) {
    for (size_t n = 0; n < permutation->size(); ++n) {
      inv_perm[(*permutation)[n]] = n;
    }
    permutation->swap(inv_perm);
  }
  // Resolve each cycle with swaps until every slot holds its target index.
  for (std::size_t n = 0; n + 1 < permutation->size(); ++n) {
    while (n != (*permutation)[n]) {
      std::size_t r = (*permutation)[n];
      graph->mutable_node()->SwapElements(n, r);
      std::swap((*permutation)[n], (*permutation)[r]);
    }
  }
}
    363 
// Removes duplicate control inputs from `node` in place. A control input is
// dropped when its producer's name was already seen among earlier inputs
// (control or data). Removal swaps with the last input, so the relative
// order of the surviving inputs may change.
void DedupControlInputs(NodeDef* node) {
  std::unordered_set<string> inputs;
  int pos = 0;
  while (pos < node->input_size()) {
    const string& input = node->input(pos);
    // insert().second is false when the name was already present; only
    // control inputs are eligible for removal.
    if (!inputs.insert(NodeName(input)).second && IsControlInput(input)) {
      // Swap-with-last then pop; do not advance pos — the swapped-in
      // element must be examined next.
      node->mutable_input()->SwapElements(pos, node->input_size() - 1);
      node->mutable_input()->RemoveLast();
    } else {
      ++pos;
    }
  }
}
    377 
    378 namespace {
    379 
// Deletes the nodes at the given indices from `graph`. `nodes_to_delete`
// must contain unique, valid indices in ascending iteration order. Each
// doomed node is swapped to the current tail — iterating indices in
// descending order so earlier swaps cannot disturb later ones — and the
// accumulated tail is deleted in a single subrange call.
// NOTE: surviving nodes may be reordered by the swaps.
template <typename UniqueContainer>
void EraseNodesFromGraphImpl(const UniqueContainer& nodes_to_delete,
                             GraphDef* graph) {
  static_assert(std::is_same<typename UniqueContainer::value_type, int>::value,
                "Need to pass container of ints");

  int last = graph->node_size() - 1;
  for (auto it = nodes_to_delete.rbegin(); it != nodes_to_delete.rend(); ++it) {
    const int index = *it;
    graph->mutable_node()->SwapElements(index, last);
    last--;
  }
  graph->mutable_node()->DeleteSubrange(last + 1, nodes_to_delete.size());
}
    394 
    395 template <typename T>
    396 inline void STLSortAndRemoveDuplicates(T* v) {
    397   std::sort(v->begin(), v->end());
    398   v->erase(std::unique(v->begin(), v->end()), v->end());
    399 }
    400 
    401 }  // namespace
    402 
// Erases nodes at the given indices. A std::set is already sorted and
// duplicate-free, as EraseNodesFromGraphImpl requires.
void EraseNodesFromGraph(const std::set<int>& nodes_to_delete,
                         GraphDef* graph) {
  EraseNodesFromGraphImpl(nodes_to_delete, graph);
}
    407 
// Erases nodes at the given indices. The vector is sorted and deduplicated
// first, so duplicates and arbitrary ordering are tolerated.
void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph) {
  STLSortAndRemoveDuplicates(&nodes_to_delete);
  EraseNodesFromGraphImpl(nodes_to_delete, graph);
}
    412 
    413 void EraseNodesFromGraph(const std::set<string>& nodes_to_delete,
    414                          GraphDef* graph) {
    415   std::vector<int> nodes_idx_to_delete;
    416   nodes_idx_to_delete.reserve(nodes_to_delete.size());
    417   for (int i = 0; i < graph->node_size(); ++i) {
    418     if (nodes_to_delete.count(graph->node(i).name()))
    419       nodes_idx_to_delete.push_back(i);
    420   }
    421   EraseNodesFromGraphImpl(nodes_idx_to_delete, graph);
    422 }
    423 
    424 #define HANDLE_DOUBLE_CASE(DTYPE)                                     \
    425   case DTYPE:                                                         \
    426     if (!SafeSetDoubleScalarTensorValue<EnumToDataType<DTYPE>::Type>( \
    427             static_cast<double>(value), tensor)) {                    \
    428       return errors::InvalidArgument("Cannot store value ", value,    \
    429                                      " in tensor of type " #DTYPE);   \
    430     }                                                                 \
    431     break
    432 
    433 #define HANDLE_INT_CASE(DTYPE)                                               \
    434   case DTYPE:                                                                \
    435     if (!SafeSetIntScalarTensorValue<EnumToDataType<DTYPE>::Type>(value,     \
    436                                                                   tensor)) { \
    437       return errors::InvalidArgument("Cannot store value ", value,           \
    438                                      " in tensor of type " #DTYPE);          \
    439     }                                                                        \
    440     break
    441 
    442 Status SetTensorValue(DataType dtype, int value, Tensor* tensor) {
    443   // TODO(rmlarsen): Support more general shapes.
    444   // TODO(lyandy): Change `value` to be int64 once int64 -> qint32 is supported.
    445   if (tensor->NumElements() != 1) {
    446     return errors::InvalidArgument(
    447         "Expected scalar tensor, got num_elements = ", tensor->NumElements());
    448   }
    449   switch (dtype) {
    450     HANDLE_DOUBLE_CASE(DT_HALF);
    451     HANDLE_DOUBLE_CASE(DT_BFLOAT16);
    452     HANDLE_DOUBLE_CASE(DT_BOOL);
    453     HANDLE_DOUBLE_CASE(DT_FLOAT);
    454     HANDLE_DOUBLE_CASE(DT_DOUBLE);
    455     HANDLE_DOUBLE_CASE(DT_UINT8);
    456     HANDLE_DOUBLE_CASE(DT_INT8);
    457     HANDLE_DOUBLE_CASE(DT_UINT16);
    458     HANDLE_DOUBLE_CASE(DT_INT16);
    459     HANDLE_DOUBLE_CASE(DT_INT32);
    460     HANDLE_DOUBLE_CASE(DT_INT64);
    461     HANDLE_DOUBLE_CASE(DT_COMPLEX64);
    462     HANDLE_DOUBLE_CASE(DT_COMPLEX128);
    463     HANDLE_INT_CASE(DT_QINT8);
    464     HANDLE_INT_CASE(DT_QUINT8);
    465     HANDLE_INT_CASE(DT_QINT16);
    466     HANDLE_INT_CASE(DT_QUINT16);
    467     HANDLE_INT_CASE(DT_QINT32);
    468     default:
    469       return errors::InvalidArgument("Unsupported type ",
    470                                      DataTypeString(dtype));
    471   }
    472   return Status::OK();
    473 }
    474 
    475 #undef HANDLE_CASE
    476 
    477 Status CheckAttrExists(const NodeDef& node, const string& key) {
    478   if (!HasNodeAttr(node, key)) {
    479     return errors::InvalidArgument("Node '", node.name(), "' lacks '", key,
    480                                    "' attr: ", node.ShortDebugString());
    481   }
    482   return Status::OK();
    483 }
    484 
    485 Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys) {
    486   for (const string& key : keys) {
    487     TF_RETURN_IF_ERROR(CheckAttrExists(node, key));
    488   }
    489   return Status::OK();
    490 }
    491 
    492 Status IsKernelRegisteredForNode(const NodeDef& node) {
    493   DeviceNameUtils::ParsedName parsed_name;
    494   if (!DeviceNameUtils::ParseFullName(node.device(), &parsed_name)) {
    495     return errors::InvalidArgument("Could not parse device name: ",
    496                                    node.device());
    497   }
    498   return FindKernelDef(DeviceType(parsed_name.type), node, nullptr, nullptr);
    499 }
    500 
    501 }  // end namespace grappler
    502 }  // end namespace tensorflow
    503