/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/costs/graph_properties.h"

#include <queue>
#include <unordered_map>
#include <unordered_set>
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/utils.h"

namespace tensorflow {
namespace grappler {
namespace {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeAndType;
using shape_inference::ShapeHandle;

template <typename Handle>
struct HashHandle {
  std::size_t operator()(const Handle& h) const { return h.Handle(); }
};
template <typename Handle>
struct CompareHandle {
  bool operator()(const Handle& h1, const Handle& h2) const {
    return h1.SameHandle(h2);
  }
};

template <typename Handle>
struct HandleToObject {};
template <>
struct HandleToObject<ShapeHandle> {
  typedef ShapeHandle Object;

  static ShapeHandle Unknown() { return ShapeHandle(); }
};

template <>
struct HandleToObject<DimensionHandle> {
  typedef int64 Object;

  static int64 Unknown() { return -1; }
};

template <typename Handle>
struct Processor {};

template <>
struct Processor<ShapeHandle> {
  // Extract the shape or dim denoted by the handle.
  void ExtractValue(ShapeHandle h, ShapeHandle* result) { *result = h; }
  // Merge the shapes or dims.
  Status Merge(ShapeHandle h1, ShapeHandle h2, ShapeHandle* result) {
    if (InferenceContext::RankKnown(*result)) {
      // The result was initialized in a previous merge to a shape of known
      // rank; make sure we preserve that information.
      return Status::OK();
    }
    if (InferenceContext::RankKnown(h1)) {
      *result = h1;
    } else {
      *result = h2;
    }
    return Status::OK();
  }
};

template <>
struct Processor<DimensionHandle> {
  // Assign a negative id to unknown dimensions, starting at -2 (-1 is already
  // reserved by TensorFlow).
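  // For instance (illustrative): extracting three unknown dimensions in a row
  // yields the symbolic values -2, -3 and -4, while a dimension already known
  // to be of size 5 is extracted as 5.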
  void ExtractValue(DimensionHandle d, int64* result) {
    if (!InferenceContext::ValueKnown(d)) {
      *result = -counter;
      counter++;
    } else {
      int64 val = InferenceContext::Value(d);
      if (val >= 0) {
        *result = val;
      } else {
        // A shape inference function generated an invalid dimension handle.
        // Use a symbolic dimension to encode this.
        *result = -counter;
        counter++;
      }
    }
  }

  // Merge the dimensions d1 and d2. Return the known dimension if there is
  // one, otherwise look for a symbolic dimension. If there is no symbolic
  // dimension and no known dimension, the dimension is fully unknown, so
  // return -1.
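  // A few illustrative cases: merging the known dimensions 5 and 5 yields 5;
  // merging the known dimension 5 with the fully unknown dimension -1 yields
  // 5; merging the symbolic dimension -2 with -1 keeps the symbolic -2; and
  // merging two different known dimensions is a fatal inconsistency.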
  Status Merge(DimensionHandle d1, DimensionHandle d2, int64* result) {
    const int64 dim1 = InferenceContext::Value(d1);
    const int64 dim2 = InferenceContext::Value(d2);

    if (dim1 >= 0 && dim2 >= 0) {
      CHECK_EQ(dim1, dim2);
      return RefineDim(dim1, result);
    } else if (dim1 >= 0 && dim2 < 0) {
      return RefineDim(dim1, result);
    } else if (dim1 < 0 && dim2 >= 0) {
      return RefineDim(dim2, result);
    } else if (dim1 < -1) {
      return RefineDim(dim1, result);
    } else if (dim2 < -1) {
      return RefineDim(dim2, result);
    } else {
      CHECK_EQ(dim1, dim2);
      CHECK_EQ(-1, dim1);
      return RefineDim(-1, result);
    }
    return Status::OK();
  }

 private:
  Status RefineDim(int64 dim, int64* result) {
    if (*result >= 0) {
      if (!(*result == dim || dim < 0)) {
        return errors::InvalidArgument("Inconsistent dimensions detected");
      }
    } else if (dim >= 0) {
      *result = dim;
    } else if (dim < *result) {
      *result = dim;
    }
    return Status::OK();
  }

  int64 counter = 2;
};

// Traditional disjoint-set data structure with path compression.
// (https://en.wikipedia.org/wiki/Disjoint-set_data_structure)
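// A minimal usage sketch (illustrative, not part of this file): merging two
// dimension handles places them in the same set, after which GetMergedValue
// returns one consistent value for both:
//
//   Processor<DimensionHandle> processor;
//   DisjointSet<DimensionHandle> dims(processor);
//   TF_CHECK_OK(dims.Merge(d1, d2));
//   int64 merged = dims.GetMergedValue(d1);  // Same result as for d2.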
template <typename Handle>
class DisjointSet {
 public:
  DisjointSet(const Processor<Handle>& processor) : processor_(processor) {}
  ~DisjointSet() {
    for (auto rep : nodes_) {
      delete rep.second;
    }
  }

  Status Merge(Handle x, Handle y);
  const typename HandleToObject<Handle>::Object GetMergedValue(Handle value);

 private:
  // All the handles that belong to the same set are part of the same tree,
  // and are ultimately represented by the root of that tree.
  struct Rep {
    // Parent in the tree used to encode the set.
    Rep* parent;
    // Rank of the tree rooted at this node; used to decide which root becomes
    // the parent when two sets are merged, keeping the trees shallow.
    int rank;
    // The value extracted from the handle.
    typename HandleToObject<Handle>::Object value;
  };

  // Create a new set for the value if none exists, or return its
  // representative node otherwise.
  Rep* Find(Handle value);

 private:
  Processor<Handle> processor_;
  std::unordered_map<Handle, Rep*, HashHandle<Handle>, CompareHandle<Handle>>
      nodes_;
};

template <typename Handle>
const typename HandleToObject<Handle>::Object
DisjointSet<Handle>::GetMergedValue(Handle value) {
  Rep* rep = Find(value);
  if (!rep) {
    // We don't know anything about this handle.
    return HandleToObject<Handle>::Unknown();
  }
  return rep->value;
}

template <typename Handle>
Status DisjointSet<Handle>::Merge(Handle x, Handle y) {
  Rep* x_root = Find(x);
  Rep* y_root = Find(y);

  // x and y are already in the same set.
  if (x_root == y_root) {
    return Status::OK();
  }
  // x and y are not in the same set, so we merge them. Take this opportunity
  // to strengthen what we know about the handles by merging the information
  // from the two subsets.
  if (x_root->rank < y_root->rank) {
    TF_RETURN_IF_ERROR(processor_.Merge(y, x, &y_root->value));
    x_root->parent = y_root;
  } else if (x_root->rank > y_root->rank) {
    TF_RETURN_IF_ERROR(processor_.Merge(x, y, &x_root->value));
    y_root->parent = x_root;
  } else {
    TF_RETURN_IF_ERROR(processor_.Merge(x, y, &x_root->value));
    // Arbitrarily make one root the new parent.
    y_root->parent = x_root;
    x_root->rank = x_root->rank + 1;
  }
  return Status::OK();
}

template <typename Handle>
typename DisjointSet<Handle>::Rep* DisjointSet<Handle>::Find(Handle value) {
  auto it = nodes_.find(value);
  if (it == nodes_.end()) {
    // This is the first time we process this handle, create an entry for it.
    Rep* node = new Rep;
    node->parent = node;
    node->rank = 0;
    processor_.ExtractValue(value, &node->value);
    nodes_[value] = node;
    return node;
  }
  // Return the representative for the set, which is the root of the tree.
  // Apply path compression to speed up future queries.
  Rep* node = it->second;
  Rep* root = node->parent;
  while (root != root->parent) {
    root = root->parent;
  }
  while (node->parent != root) {
    Rep* next = node->parent;
    node->parent = root;
    node = next;
  }
  return root;
}

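// Returns true if the node is a V2 queue: queue op type strings end with
// "QueueV2" (e.g. "FIFOQueueV2", "RandomShuffleQueueV2").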
bool IsQueue(const Node& node) {
  StringPiece type(node.type_string());
  return type.ends_with("QueueV2");
}

// Returns true if the node is an Enter op AND its input is a Queue.
bool IsEnterWithQueue(const Node& node) {
  if (node.IsEnter()) {
    const Node* in_node;
    TF_CHECK_OK(node.input_node(0, &in_node));
    return IsQueue(*in_node);
  }
  return false;
}

bool HasAnyUnknownDimensions(const TensorShapeProto& proto) {
  if (proto.unknown_rank()) {
    return true;
  }
  for (const auto& dim : proto.dim()) {
    if (dim.size() < 0) {
      return true;
    }
  }
  return false;
}

void VerboseLogUnknownDimensionSources(
    const Graph& graph,
    const std::map<string, std::vector<OpInfo::TensorProperties>>&
        input_properties_map,
    const std::map<string, std::vector<OpInfo::TensorProperties>>&
        output_properties_map) {
  if (!VLOG_IS_ON(2)) {
    return;
  }

  VLOG(2) << "Nodes with known inputs, but with unknown output dimensions:";

  // Find all nodes in the graph for which we do not have any unknown
  // dimensions in their inputs, but we have some unknown dimensions in their
  // outputs.
  std::map<string, int> op_to_count;
  for (const Node* const node : graph.nodes()) {
    if (node->num_outputs() == 0) {
      continue;
    }

    const auto& input_properties = input_properties_map.at(node->name());
    const auto& output_properties = output_properties_map.at(node->name());

    bool has_unknown_inputs = false;
    for (int i = 0; i < node->num_inputs(); ++i) {
      if (HasAnyUnknownDimensions(input_properties[i].shape())) {
        has_unknown_inputs = true;
        break;
      }
    }

    if (has_unknown_inputs) {
      continue;
    }

    for (int i = 0; i < node->num_outputs(); ++i) {
      if (HasAnyUnknownDimensions(output_properties[i].shape())) {
        string inputs = "input_shapes=[";
        for (int j = 0; j < node->num_inputs(); ++j) {
          inputs +=
              PartialTensorShape::DebugString(input_properties[j].shape());
        }
        inputs += "]";

        string outputs = "output_shapes=[";
        for (int j = 0; j < node->num_outputs(); ++j) {
          outputs +=
              PartialTensorShape::DebugString(output_properties[j].shape());
        }
        outputs += "]";

        VLOG(2) << "Node: " << node->name() << ", Op: " << node->def().op()
                << ", " << inputs << ", " << outputs;

        op_to_count[node->def().op()]++;

        // Don't log again for this node.
        break;
      }
    }
  }
  VLOG(2) << "Op types with known inputs, but with unknown output dimensions "
          << "(format: <op_type> (<count>)):";
  for (const auto& p : op_to_count) {
    VLOG(2) << p.first << " (" << p.second << ")";
  }
}

}  // namespace

// Queue of nodes to process. Nodes can be enqueued in any order, but will be
// dequeued in (roughly) topological order. Propagating shapes following a
// topological ordering isn't required for correctness, but it helps speed
// things up since it avoids processing the same node multiple times as its
// input information is refined.
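// For instance (illustrative): pushing nodes with ids 7, 3 and 5 in that
// order makes pop() return the node with id 3, then 5, then 7.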
class TopoQueue {
 public:
  void push(const Node* n) { queue_.insert(n); }
  const Node* pop() {
    CHECK(!empty());
    auto it = queue_.begin();
    const Node* n = *it;
    queue_.erase(it);
    return n;
  }

  bool empty() const { return queue_.empty(); }
  std::size_t size() const { return queue_.size(); }

 private:
  // Graph nodes are created in (roughly) topological order. Therefore we can
  // use their id to ensure they're sorted topologically.
  struct CompareNodes {
    bool operator()(const Node* lhs, const Node* rhs) const {
      return lhs->id() < rhs->id();
    }
  };
  std::set<const Node*, CompareNodes> queue_;
};

// Merge and relax symbolic shapes.
// Each symbolic shape or dimension is represented by a handle. Unlike the TF
// shape refiner which creates new handles every time it processes an unknown
// shape/dimension, the symbolic shape refiner assigns a specific handle to
// each unknown shape/dimension of a given node.
class SymbolicShapeRefiner {
 public:
  explicit SymbolicShapeRefiner(ShapeRefiner* shape_refiner)
      : shape_refiner_(shape_refiner) {}

  InferenceContext* GetContext(const Node* node) {
    return shape_refiner_->GetContext(node);
  }
  Status UpdateNode(const Node* node, bool relax, bool* refined) {
    return shape_refiner_->UpdateNode(node, relax, refined);
  }
  Status SetUnknownShape(const Node* node, int output_port) {
    shape_inference::ShapeHandle shape =
        GetUnknownOutputShape(node, output_port);
    InferenceContext* ctx = GetContext(node);
    if (ctx == nullptr) {
      return errors::InvalidArgument("Missing context");
    }
    ctx->set_output(output_port, shape);
    return Status::OK();
  }

  struct ShapeId {
    const Node* node;
    int port_id;
    bool operator==(const ShapeId& other) const {
      return node == other.node && port_id == other.port_id;
    }
  };
  struct HashShapeId {
    std::size_t operator()(const ShapeId& shp) const {
      return std::hash<const Node*>{}(shp.node) + shp.port_id;
    }
  };

  struct DimId {
    const Node* node;
    int port_id;
    int dim_index;
    bool operator==(const DimId& other) const {
      return node == other.node && port_id == other.port_id &&
             dim_index == other.dim_index;
    }
  };

  struct HashDimId {
    std::size_t operator()(const DimId& dim) const {
      return std::hash<const Node*>{}(dim.node) + dim.port_id + dim.dim_index;
    }
  };

  // Compute the shape of the tensors output by node 'node' at output port
  // 'port_index' as the intersection of shape1 and shape2.
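  // Illustrative examples, writing ? for an unknown dimension: intersecting
  // [2, ?] with [?, 3] yields [2, 3]; intersecting [2, 3] with [2, 4] (an
  // inconsistency) replaces dimension 1 with a fresh unknown dimension; and
  // intersecting shapes of different known ranks returns a fully unknown
  // shape.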
  ShapeHandle OutputAsIntersection(const Node* node, int port_index,
                                   ShapeHandle shape1, ShapeHandle shape2) {
    if (shape1.SameHandle(shape2)) {
      return shape1;
    }
    InferenceContext* ctx = shape_refiner_->GetContext(node);
    ShapeHandle merged = shape1;
    if (!ctx->RankKnown(shape2) && !ctx->RankKnown(shape1)) {
      // Return either one since they're expected to represent the same value.
      return shape1;
    } else if (!ctx->RankKnown(shape2) && ctx->RankKnown(shape1)) {
      return shape1;
    } else if (ctx->RankKnown(shape2) && !ctx->RankKnown(shape1)) {
      return shape2;
    } else {
      const int rank = ctx->Rank(shape1);
      if (ctx->Rank(shape2) != rank) {
        // We detected an inconsistency, return an unknown shape. This can
        // happen in the fanout of a merge node since during the initial
        // propagation we optimistically assume that all the inputs to the
        // merge node have the same shape.
        return GetUnknownOutputShape(node, port_index);
      }
      for (int d = 0; d < rank; ++d) {
        if (!ctx->Dim(shape1, d).SameHandle(ctx->Dim(shape2, d))) {
          if (ctx->Value(ctx->Dim(shape1, d)) !=
              ctx->Value(ctx->Dim(shape2, d))) {
            DimensionHandle new_dim;
            if (ctx->Value(ctx->Dim(shape1, d)) < 0) {
              new_dim = ctx->Dim(shape2, d);
            } else if (ctx->Value(ctx->Dim(shape2, d)) < 0) {
              new_dim = ctx->Dim(shape1, d);
            } else {
              new_dim = GetUnknownOutputDim(node, port_index, d);
            }
            TF_CHECK_OK(ctx->ReplaceDim(merged, d, new_dim, &merged));
          }
        }
      }
    }
    return merged;
  }

  // Compute the shape of the tensors output by node 'node' at output port
  // 'port_index' as the union of shape1 and shape2.
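  // Illustrative examples: the union of [2, 3] and [2, 4] is [2, ?] (a fresh
  // symbolic dimension at index 1), and the union of two shapes of different
  // ranks is a fully unknown shape.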
  ShapeHandle OutputAsUnion(const Node* node, int port_index,
                            ShapeHandle shape1, ShapeHandle shape2) {
    if (shape1.SameHandle(shape2)) {
      return shape1;
    }
    InferenceContext* ctx = shape_refiner_->GetContext(node);
    ShapeHandle relaxed = shape1;
    const int rank = ctx->Rank(shape1);
    if (!ctx->RankKnown(shape2) || ctx->Rank(shape2) != rank) {
      relaxed = GetUnknownOutputShape(node, port_index);
    } else {
      for (int d = 0; d < rank; ++d) {
        if (!ctx->Dim(shape1, d).SameHandle(ctx->Dim(shape2, d))) {
          int64 val1 = ctx->Value(ctx->Dim(shape1, d));
          int64 val2 = ctx->Value(ctx->Dim(shape2, d));
          if (val1 != val2 || (val1 < 0 && val2 < 0)) {
            DimensionHandle new_dim = GetUnknownOutputDim(node, port_index, d);
            TF_CHECK_OK(ctx->ReplaceDim(relaxed, d, new_dim, &relaxed));
          }
        }
      }
    }
    return relaxed;
  }

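  // Returns true if the two shapes are known to be equivalent: same handle,
  // both of unknown rank, or same rank with pairwise identical handles or
  // equal known dimension values.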
  bool EquivalentShapes(ShapeHandle s1, ShapeHandle s2) const {
    if (s1.SameHandle(s2)) {
      return true;
    }
    if (InferenceContext::Rank(s1) != InferenceContext::Rank(s2)) {
      return false;
    }
    if (!InferenceContext::RankKnown(s1) && !InferenceContext::RankKnown(s2)) {
      return true;
    }
    const int rank = InferenceContext::Rank(s1);
    for (int i = 0; i < rank; ++i) {
      if (!InferenceContext::DimKnownRank(s1, i).SameHandle(
              InferenceContext::DimKnownRank(s2, i))) {
        int64 val1 =
            InferenceContext::Value(InferenceContext::DimKnownRank(s1, i));
        int64 val2 =
            InferenceContext::Value(InferenceContext::DimKnownRank(s2, i));
        if (val1 >= 0 && val2 >= 0 && val1 == val2) {
          continue;
        }
        return false;
      }
    }
    return true;
  }

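  // Returns true if the two lists pairwise contain equivalent shapes and
  // identical dtypes.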
  bool EquivalentShapesAndTypes(const std::vector<ShapeAndType>& st1,
                                const std::vector<ShapeAndType>& st2) const {
    if (st1.size() != st2.size()) {
      return false;
    }
    for (int i = 0; i < st1.size(); ++i) {
      const ShapeAndType& s1 = st1[i];
      const ShapeAndType& s2 = st2[i];
      if (s1.dtype != s2.dtype) {
        return false;
      }
      if (!EquivalentShapes(s1.shape, s2.shape)) {
        return false;
      }
    }
    return true;
  }

 private:
  // Return the one ShapeHandle used to denote a fully unknown shape for a
  // node output.
  ShapeHandle GetUnknownOutputShape(const Node* node, int index) {
    ShapeId id{node, index};
    auto it = unknown_shapes_.find(id);
    if (it != unknown_shapes_.end()) {
      return it->second;
    }
    InferenceContext* c = shape_refiner_->GetContext(node);
    ShapeHandle shp = c->UnknownShape();
    unknown_shapes_[id] = shp;
    return shp;
  }
  // Return the one DimensionHandle used to denote a fully unknown dimension
  // for a node output.
  DimensionHandle GetUnknownOutputDim(const Node* node, int index, int dim_id) {
    DimId id{node, index, dim_id};
    auto it = unknown_dims_.find(id);
    if (it != unknown_dims_.end()) {
      return it->second;
    }
    InferenceContext* c = shape_refiner_->GetContext(node);
    DimensionHandle dim = c->UnknownDim();
    unknown_dims_[id] = dim;
    return dim;
  }

  ShapeRefiner* shape_refiner_;

  std::unordered_map<ShapeId, ShapeHandle, HashShapeId> unknown_shapes_;
  std::unordered_map<DimId, DimensionHandle, HashDimId> unknown_dims_;
};

// Keep track of shapes and dimensions in a graph.
// In particular, use disjoint sets to track equivalence between shapes and
// dims, and consolidate the information globally.
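// Typical usage (an illustrative sketch): merge all the shape and dimension
// equivalences discovered during inference, then read back the consolidated
// result:
//
//   SymbolicShapeManager manager;
//   TF_CHECK_OK(manager.Merge(s1, s2));
//   OpInfo::TensorProperties props;
//   manager.AsTensorProperties(s1, DT_FLOAT, &props);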
class SymbolicShapeManager {
 public:
  SymbolicShapeManager() : shapes_(shape_processor_), dims_(dim_processor_) {}

  Status Merge(ShapeHandle s1, ShapeHandle s2) {
    if (!s1.IsSet() || !s2.IsSet()) {
      return Status::OK();
    }
    TF_RETURN_IF_ERROR(shapes_.Merge(s1, s2));
    if (InferenceContext::Rank(s1) > 0 && InferenceContext::Rank(s2) > 0) {
      CHECK_EQ(InferenceContext::Rank(s1), InferenceContext::Rank(s2));
      for (int i = 0; i < InferenceContext::Rank(s1); ++i) {
        TF_RETURN_IF_ERROR(dims_.Merge(InferenceContext::DimKnownRank(s1, i),
                                       InferenceContext::DimKnownRank(s2, i)));
      }
    }
    return Status::OK();
  }
  Status Merge(DimensionHandle d1, DimensionHandle d2) {
    if (!d1.IsSet() || !d2.IsSet()) {
      return Status::OK();
    }
    return dims_.Merge(d1, d2);
  }

  void AsTensorProperties(const ShapeHandle& shape, const DataType& type,
                          OpInfo::TensorProperties* properties) {
    properties->set_dtype(type);
    ShapeHandle actual_shape = shapes_.GetMergedValue(shape);
    if (!InferenceContext::RankKnown(actual_shape)) {
      properties->mutable_shape()->set_unknown_rank(true);
    } else {
      for (int j = 0; j < InferenceContext::Rank(actual_shape); ++j) {
        shape_inference::DimensionHandle dim =
            InferenceContext::DimKnownRank(actual_shape, j);
        int64 d = dims_.GetMergedValue(dim);
        properties->mutable_shape()->add_dim()->set_size(d);
      }
    }
  }

 private:
  Processor<ShapeHandle> shape_processor_;
  DisjointSet<shape_inference::ShapeHandle> shapes_;
  Processor<DimensionHandle> dim_processor_;
  DisjointSet<shape_inference::DimensionHandle> dims_;
};

Status GraphProperties::MergeEnqueueShapesAndTypes(
    SymbolicShapeRefiner* shape_refiner, const Node* qnode,
    const std::vector<ShapeAndType>& shapes_and_types,
    std::vector<ShapeAndType>* queue_shapes_and_types) {
  if (shapes_and_types.size() != queue_shapes_and_types->size()) {
    return errors::InvalidArgument(
        "Enqueue nodes mixed number of tensors: ", shapes_and_types.size(),
        "  vs ", queue_shapes_and_types->size());
  }
  for (size_t i = 0; i < shapes_and_types.size(); ++i) {
    const ShapeAndType& a = shapes_and_types[i];
    ShapeAndType& b = (*queue_shapes_and_types)[i];
    if (a.dtype != b.dtype) {
      return errors::InvalidArgument("Enqueue nodes mixed dtypes for tensor ",
                                     i, ": ", DataTypeString(a.dtype), " vs ",
                                     DataTypeString(b.dtype));
    }

    b.shape = shape_refiner->OutputAsIntersection(qnode, i, a.shape, b.shape);
  }
  return Status::OK();
}

Status GraphProperties::RelaxEnqueueShapesAndMergeTypes(
    SymbolicShapeRefiner* shape_refiner, const Node* qnode,
    const std::vector<ShapeAndType>& shapes_and_types,
    std::vector<ShapeAndType>* queue_shapes_and_types) {
  if (shapes_and_types.size() != queue_shapes_and_types->size()) {
    return errors::InvalidArgument(
        "Enqueue nodes mixed number of tensors: ", shapes_and_types.size(),
        "  vs ", queue_shapes_and_types->size());
  }
  for (size_t i = 0; i < shapes_and_types.size(); ++i) {
    const ShapeAndType& a = shapes_and_types[i];
    ShapeAndType& b = (*queue_shapes_and_types)[i];
    if (a.dtype != b.dtype) {
      return errors::InvalidArgument("Enqueue nodes mixed dtypes for tensor ",
                                     i, ": ", DataTypeString(a.dtype), " vs ",
                                     DataTypeString(b.dtype));
    }

    b.shape = shape_refiner->OutputAsUnion(qnode, i, a.shape, b.shape);
  }
  return Status::OK();
}

// If a Merge node has a NextIteration node as an input then that input will
// try to forward an UnknownShape at graph construction time. However, the
// Merge shape function will always propagate an UnknownShape if any of its
// inputs are UnknownShapes. So we need to ignore the input from NextIteration
// nodes to propagate any known shape from the Merge node.
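// For example (illustrative): in a simple while loop, the Merge node sees the
// Enter input, which may carry a known shape such as [10, 32], and the
// NextIteration back edge, which is still unknown on the first pass. Ignoring
// the back edge lets the Merge output start from [10, 32] rather than from an
// UnknownShape.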
Status GraphProperties::UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
                                        const Node* node, bool relax,
                                        TopoQueue* new_shapes) {
  InferenceContext* c = shape_refiner->GetContext(node);
  CHECK_NE(c, nullptr);

  ShapeHandle out1;
  TF_RETURN_IF_ERROR(c->WithRank(c->output(1), 0, &out1));
  c->set_output(1, out1);

  ShapeHandle out;
  bool out_initialized = false;
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) {
      continue;
    }
    // Skip back edges during the initial propagation phase. This is equivalent
    // to assuming that all the inputs to the merge nodes are fed by the same
    // shape, and will be corrected as needed in the relaxation phase.
    if (!relax && e->src()->IsNextIteration()) {
      continue;
    }

    InferenceContext* in = shape_refiner->GetContext(e->src());
    ShapeHandle input = in->output(e->src_output());
    if (relax) {
      c->RelaxInput(e->dst_input(), input);
    } else {
      c->MergeInput(e->dst_input(), input);
    }
    if (!out_initialized) {
      out_initialized = true;
      out = input;
      continue;
    }
    if (relax) {
      out = shape_refiner->OutputAsUnion(node, 0, input, out);
    } else {
      out = shape_refiner->OutputAsIntersection(node, 0, input, out);
    }
  }

  if (!shape_refiner->EquivalentShapes(out, c->output(0))) {
    c->set_output(0, out);
    new_shapes->push(node);
  }

  return Status::OK();
}

Status GraphProperties::OverwriteFedPorts(
    SymbolicShapeRefiner* shape_refiner,
    const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
    const Node* node, TopoQueue* new_shapes) const {
  auto it = fed_ports.find(node->name());
  Status status;
  if (it != fed_ports.end()) {
    // It is possible to feed node output ports with tensors of any shape: as
    // a result, the shape of a fed port is completely unknown.
    for (const int output_port : it->second) {
      status.Update(shape_refiner->SetUnknownShape(node, output_port));
    }
    new_shapes->push(node);
  }
  return status;
}

// Manually propagate the input shape for Enter nodes and update any Merge
// node outputs.
Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,
                                    const Node* node, bool relax,
                                    TopoQueue* new_shapes) {
  auto enter_ctx = shape_refiner->GetContext(node);
  CHECK_NE(enter_ctx, nullptr);

  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) {
      continue;
    }
    InferenceContext* in = shape_refiner->GetContext(e->src());
    ShapeHandle input = in->output(e->src_output());
    if (!enter_ctx->output(0).SameHandle(input)) {
      if (relax) {
        enter_ctx->RelaxInput(0, input);
      } else {
        enter_ctx->MergeInput(0, input);
      }
      enter_ctx->set_output(0, input);
      new_shapes->push(node);
    }
  }
  return Status::OK();
}

Status GraphProperties::UpdateShapes(
    SymbolicShapeRefiner* shape_refiner, bool relax,
    const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
    const Node* n, TopoQueue* new_shapes) const {
  if (n->IsEnter()) {
    // The Enter shape function always forwards an UnknownShape, so do the
    // right thing here.
    TF_RETURN_IF_ERROR(UpdateEnter(shape_refiner, n, relax, new_shapes));
  } else if (n->IsMerge()) {
    // Properly handle merge nodes.
    TF_RETURN_IF_ERROR(UpdateMergeNode(shape_refiner, n, relax, new_shapes));
  } else {
    // Rely on regular TF shape refinement for all the other nodes.
    bool updated = false;
    TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, relax, &updated));
    if (updated) {
      // We want to avoid propagating through loops on the merge pass because
      // the shapes are not guaranteed to converge.
      if (relax || !n->IsNextIteration()) {
        new_shapes->push(n);
      }
    }
  }
  // Nodes can be fed with any shape. The TensorFlow shape inference code can't
  // handle this properly, so overwrite its behavior here.
  return OverwriteFedPorts(shape_refiner, fed_ports, n, new_shapes);
}

// Propagates the shapes in the transitive fan-out of <new_shapes>.
Status GraphProperties::PropagateShapes(
    SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes,
    const std::unordered_map<const Node*, std::unordered_set<const Node*>>&
        resources,
    const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
    int num_loops) const {
  // Limit the number of iterations to prevent infinite loops in the presence
  // of incorrect shape functions. The algorithm should converge in at most
  // num_nested_loops^2 * max_rank iterations. We approximate max_rank with
  // the constant 4. The same reasoning applies to resources.
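  // For instance (illustrative): a graph with 1000 nodes and 2 nested loops
  // gets max_loop_iterations = 4 * 1000 * max(1, 2 * 2) = 16000 below.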
  VLOG(1) << "Propagating (relax=" << relax << ") " << new_shapes->size()
          << " new shapes through " << num_loops << " loops and "
          << resources.size() << " resources" << std::endl;

  const int64 max_loop_length = item_.graph.node_size();
  const int64 max_rank = 4;
  const int64 max_loop_iterations =
      max_rank * max_loop_length * std::max<int64>(1, num_loops * num_loops);
  const int64 num_queues = resources.size();
  const int64 max_resource_iterations = num_queues * num_queues * max_rank;

  int64 num_resource_iterations = 0;
  do {
    int64 num_loop_iterations = 0;
    while (!new_shapes->empty() &&
           num_loop_iterations++ < max_loop_iterations) {
      const Node* n = new_shapes->pop();
      for (const Edge* e : n->out_edges()) {
        if (!e->IsControlEdge()) {
          const Node* fanout = e->dst();
          TF_RETURN_IF_ERROR(UpdateShapes(shape_refiner, relax, fed_ports,
                                          fanout, new_shapes));
        }
      }
    }

    for (const auto& resource : resources) {
      // Resources need special handling: since the enqueue nodes are in the
      // fanout of the queues, we need to manually propagate the shapes from
      // the enqueue node to the corresponding queue.
      TF_RETURN_IF_ERROR(UpdateResource(resource.first, resource.second,
                                        shape_refiner, relax, new_shapes));
    }
  } while (!new_shapes->empty() &&
           num_resource_iterations++ < max_resource_iterations);

  if (!new_shapes->empty()) {
    return errors::Internal("Shape inference failed to converge");
  }

  return Status::OK();
}

Status GraphProperties::UpdateResource(
    const Node* qnode, const std::unordered_set<const Node*>& queue_inputs,
    SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes) {
  // Proceed only if qnode is a queue or an Enter with queue input.
  if (!IsQueue(*qnode) && !IsEnterWithQueue(*qnode)) {
    return Status::OK();
  }
  auto qctx = shape_refiner->GetContext(qnode);
  if (!qctx) {
    return Status::OK();
  }
  auto* queue_handle_data = qctx->output_handle_shapes_and_types(0);

  // Merge all inputs into the enqueue node, regardless of which phase we
  // are in.
  std::vector<ShapeAndType> queue_shapes_and_types;
  if (queue_handle_data) {
    queue_shapes_and_types = *queue_handle_data;
  }
  for (const auto& node : queue_inputs) {
    auto ctx = shape_refiner->GetContext(node);
    if (!ctx) {
      continue;
    }
    // TODO(bsteiner): handle EnqueueMany as well.
    if (node->type_string().find("Enqueue") != std::string::npos &&
        node->type_string().find("EnqueueMany") == std::string::npos) {
      std::vector<ShapeAndType> shapes_and_types;
      for (int i = 1; i < ctx->num_inputs(); ++i) {
        shapes_and_types.push_back({ctx->input(i), node->input_type(i)});
      }
      if (queue_shapes_and_types.empty()) {
        queue_shapes_and_types = shapes_and_types;
      } else {
        if (relax) {
          TF_RETURN_IF_ERROR(RelaxEnqueueShapesAndMergeTypes(
              shape_refiner, qnode, shapes_and_types, &queue_shapes_and_types));
        } else {
          TF_RETURN_IF_ERROR(MergeEnqueueShapesAndTypes(
              shape_refiner, qnode, shapes_and_types, &queue_shapes_and_types));
        }
      }
    }
  }

  if (queue_handle_data == nullptr ||
      !shape_refiner->EquivalentShapesAndTypes(*queue_handle_data,
                                               queue_shapes_and_types)) {
    qctx->set_output_handle_shapes_and_types(0, queue_shapes_and_types);

    new_shapes->push(qnode);
  }

  return Status::OK();
}

Status GraphProperties::InferStatically(bool assume_valid_feeds) {
  Graph graph(OpRegistry::Global());
  FunctionLibraryDefinition function_library(graph.op_registry(),
                                             item_.graph.library());
  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
  shape_refiner.set_require_shape_inference_fns(false);
  shape_refiner.set_disable_constant_propagation(true);
  shape_refiner.set_function_library_for_shape_inference(&function_library);
  ImportGraphDefOptions options;
  // Graph optimization happens at a late stage of graph execution, after
  // colocation constraints have already been validated and the device
  // placement of nodes has completed, so there is no need to validate the
  // colocation constraints again.
  options.validate_colocation_constraints = false;
  Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner);
  TF_RETURN_IF_ERROR(s);

  std::unordered_map<string, std::unordered_set<int>> fed_ports;
  if (!assume_valid_feeds) {
    for (const auto& feed : item_.feed) {
      int port_index = 0;
      string node_name = ParseNodeName(feed.first, &port_index);
      fed_ports[node_name].insert(port_index);
    }
  }

  // List the resources and the nodes using them. Also collect the Enter and
  // Merge nodes.
  std::unordered_map<const Node*, std::unordered_set<const Node*>> resources;
  std::unordered_set<const Node*> enter_nodes;
  std::unordered_set<const Node*> merge_nodes;
  std::unordered_set<const Node*> fed_nodes;
  int num_loops = 0;
  for (const Node* const node : graph.nodes()) {
    for (int i = 0; i < node->num_inputs(); ++i) {
      if (node->input_type(i) == DataType::DT_RESOURCE) {
        const Node* resource;
        TF_CHECK_OK(node->input_node(i, &resource));
        resources[resource].insert(node);
      }
    }
    if (node->IsEnter()) {
      enter_nodes.insert(node);
    } else if (node->IsMerge()) {
      merge_nodes.insert(node);
    } else if (node->IsNextIteration()) {
      ++num_loops;
    }
    if (fed_ports.find(node->name()) != fed_ports.end()) {
      fed_nodes.insert(node);
    }
  }

  SymbolicShapeRefiner refiner(&shape_refiner);

  // We propagate shapes through the graph in two phases. In the first phase,
  // we exclusively merge shapes but do not propagate shapes through the back
  // edge of loops (i.e. the NextIteration node). Then, in the second phase,
  // we exclusively relax shapes and propagate shapes through loops until
  // reaching a fixed point.
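  // Concretely, relax == 0 below runs the merge phase and relax == 1 runs
  // the relaxation phase.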
  for (int relax = 0; relax < 2; relax++) {
    TopoQueue new_shapes;
    // Force the propagation of shapes of Enter nodes manually (the Enter
    // shape function always forwards an UnknownShape).
    for (const Node* node : enter_nodes) {
      TF_RETURN_IF_ERROR(
          UpdateShapes(&refiner, relax, fed_ports, node, &new_shapes));
    }
    // Seed the propagation of shapes through merge nodes.
    for (const Node* node : merge_nodes) {
      TF_RETURN_IF_ERROR(
          UpdateShapes(&refiner, relax, fed_ports, node, &new_shapes));
    }
    // Also seed the propagation of shapes in the fanout of fed nodes.
    for (const Node* node : fed_nodes) {
      TF_RETURN_IF_ERROR(
          OverwriteFedPorts(&refiner, fed_ports, node, &new_shapes));
    }
    // Propagate shapes normally.
    TF_RETURN_IF_ERROR(PropagateShapes(&refiner, relax, &new_shapes, resources,
                                       fed_ports, num_loops));
  }

  // Track shapes globally across the graph.
  SymbolicShapeManager shape_manager;
  bool found_error = false;
  for (const Node* const node : graph.nodes()) {
    auto node_ctx = shape_refiner.GetContext(node);
    if (!node_ctx) {
      continue;
    }
    // Skip any information that comes from fed nodes.
    if (fed_ports.find(node->name()) != fed_ports.end()) {
      continue;
    }
    for (const auto& merged_shapes : node_ctx->MergedShapes()) {
      if (!shape_manager.Merge(merged_shapes.first, merged_shapes.second)
               .ok()) {
        found_error = true;
        break;
      }
    }
    for (const auto& merged_dims : node_ctx->MergedDims()) {
      if (!shape_manager.Merge(merged_dims.first, merged_dims.second).ok()) {
        found_error = true;
        break;
      }
    }
    if (found_error) {
      // The shapes aren't consistent; we can't infer safely, so discard all
      // the information discovered so far.
      shape_manager = SymbolicShapeManager();
      break;
    }
  }

  for (const Node* const node : graph.nodes()) {
    VLOG(3) << "Filling in graph properties for node: " << node->name();
    auto ctx = shape_refiner.GetContext(node);
    if (!ctx) {
      continue;
    }

    // Fill input properties.
    {
      CHECK_EQ(ctx->num_inputs(), node->num_inputs());
      auto& input_properties = input_properties_[node->name()];

      // Should always be empty; node names in the graph are supposed to be
      // unique.
      CHECK_EQ(input_properties.size(), 0);

      input_properties.resize(ctx->num_inputs());
      for (int i = 0; i < ctx->num_inputs(); ++i) {
        shape_manager.AsTensorProperties(ctx->input(i), node->input_type(i),
                                         &input_properties[i]);
      }
      for (const auto& edge : node->in_edges()) {
        if (edge->IsControlEdge()) {
          continue;
        }
        if (!edge->src()->IsConstant()) {
          continue;
        }
        const int input_id = edge->dst_input();
        if (input_id >= input_properties.size()) {
          continue;
        }
        const NodeDef& src_def = edge->src()->def();
        const TensorProto& raw_val = src_def.attr().at("value").tensor();
        *input_properties[input_id].mutable_value() = raw_val;
      }
    }

    // Fill output properties.
    {
      CHECK_EQ(ctx->num_outputs(), node->num_outputs());
      auto& output_properties = output_properties_[node->name()];

      // Should always be empty; node names in the graph are supposed to be
      // unique.
      CHECK_EQ(output_properties.size(), 0);

      output_properties.resize(ctx->num_outputs());
      for (int i = 0; i < ctx->num_outputs(); ++i) {
        shape_manager.AsTensorProperties(ctx->output(i), node->output_type(i),
                                         &output_properties[i]);
      }
    }
  }

  // Help trace the unknown dimensions to their origins.
  VerboseLogUnknownDimensionSources(graph, input_properties_,
                                    output_properties_);

  return Status::OK();
}

Status GraphProperties::InferDynamically(Cluster* cluster) {
  TF_RETURN_IF_ERROR(cluster->Initialize(item_));

  // Runs the model once to collect the shapes in the cost model.
  RunMetadata metadata;
  TF_RETURN_IF_ERROR(
      cluster->Run(item_.graph, item_.feed, item_.fetch, &metadata));

  return InferFromCostGraph(metadata.cost_graph());
}

Status GraphProperties::AnnotateOutputShapes(GraphDef* output_graph_def) const {
  *output_graph_def = item_.graph;
  for (int i = 0; i < output_graph_def->node_size(); i++) {
    auto node = output_graph_def->mutable_node(i);
    AttrValue attr_output_shape;
    auto tensor_properties = GetOutputProperties(node->name());
    for (const auto& tensor_property : tensor_properties) {
      *attr_output_shape.mutable_list()->add_shape() = tensor_property.shape();
    }
    (*node->mutable_attr())["_output_shapes"] = attr_output_shape;
  }
  return Status::OK();
}

Status GraphProperties::InferFromCostGraph(const CostGraphDef& cost_graph) {
  if (cost_graph.node_size() == 0) {
    LOG(WARNING) << "cost_graph is empty: nothing can be inferred!";
  }
  std::unordered_map<string, const CostGraphDef::Node*> name_to_cost;
  std::unordered_map<string, const NodeDef*> name_to_node;  // Always empty.
  for (auto& node : cost_graph.node()) {
    name_to_cost[node.name()] = &node;

    std::vector<OpInfo::TensorProperties> output_properties;
    for (const auto& out : node.output_info()) {
      OpInfo::TensorProperties properties;
      properties.set_dtype(out.dtype());
      *properties.mutable_shape() = out.shape();
      output_properties.push_back(properties);
    }
    output_properties_[node.name()] = output_properties;
  }

  for (const auto& node : item_.graph.node()) {
    // Skip nodes that are not in the cost graph: either they weren't run
    // because they aren't in the intersection of the transitive fan-in of a
    // fetch node and the transitive fan-out of an input, or they were
    // optimized away by the optimizer.
    auto it = name_to_cost.find(node.name());
    if (it == name_to_cost.end()) {
      continue;
    }
    std::vector<OpInfo::TensorProperties> inputs =
        FindInputFeatures(node, name_to_cost, name_to_node);

    input_properties_[node.name()] = inputs;
  }
  return Status::OK();
}

bool GraphProperties::HasInputProperties(const string& name) const {
  return input_properties_.find(name) != input_properties_.end();
}

bool GraphProperties::HasOutputProperties(const string& name) const {
  return output_properties_.find(name) != output_properties_.end();
}

const std::vector<OpInfo::TensorProperties>&
GraphProperties::GetInputProperties(const string& node_name) const {
  auto it = input_properties_.find(node_name);
  if (it != input_properties_.end()) {
    return it->second;
  }
  return missing_properties_;
}

const std::vector<OpInfo::TensorProperties>&
GraphProperties::GetOutputProperties(const string& node_name) const {
  auto it = output_properties_.find(node_name);
  if (it != output_properties_.end()) {
    return it->second;
  }
  return missing_properties_;
}

}  // end namespace grappler
}  // end namespace tensorflow