1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/grappler/costs/graph_properties.h" 17 18 #include <queue> 19 #include <unordered_map> 20 #include <unordered_set> 21 #include "tensorflow/core/common_runtime/shape_refiner.h" 22 #include "tensorflow/core/framework/tensor_shape.pb.h" 23 #include "tensorflow/core/graph/graph_constructor.h" 24 #include "tensorflow/core/grappler/costs/utils.h" 25 #include "tensorflow/core/grappler/utils.h" 26 27 namespace tensorflow { 28 namespace grappler { 29 namespace { 30 31 using shape_inference::DimensionHandle; 32 using shape_inference::InferenceContext; 33 using shape_inference::ShapeAndType; 34 using shape_inference::ShapeHandle; 35 36 template <typename Handle> 37 struct HashHandle { 38 std::size_t operator()(const Handle& h) const { return h.Handle(); } 39 }; 40 template <typename Handle> 41 struct CompareHandle { 42 bool operator()(const Handle& h1, const Handle& h2) const { 43 return h1.SameHandle(h2); 44 } 45 }; 46 47 template <typename Handle> 48 struct HandleToObject {}; 49 template <> 50 struct HandleToObject<ShapeHandle> { 51 typedef ShapeHandle Object; 52 53 static ShapeHandle Unknown() { return ShapeHandle(); } 54 }; 55 56 template <> 57 struct HandleToObject<DimensionHandle> { 58 typedef int64 Object; 59 60 static int64 Unknown() { return -1; } 61 }; 
62 63 template <typename Handle> 64 struct Processor {}; 65 66 template <> 67 struct Processor<ShapeHandle> { 68 // Extract the shape or dim denoted by the handle. 69 void ExtractValue(ShapeHandle h, ShapeHandle* result) { *result = h; } 70 // Merge the shapes or dims. 71 Status Merge(ShapeHandle h1, ShapeHandle h2, ShapeHandle* result) { 72 if (InferenceContext::RankKnown(*result)) { 73 // The result was initialized in a previous merge to a shape of known 74 // rank, make sure we preserve that information. 75 return Status::OK(); 76 } 77 if (InferenceContext::RankKnown(h1)) { 78 *result = h1; 79 } else { 80 *result = h2; 81 } 82 return Status::OK(); 83 } 84 }; 85 86 template <> 87 struct Processor<DimensionHandle> { 88 // Assign a negative id to unknown dimensions, starting at -2 (the -1 id 89 // reserved by TensorFlow). 90 void ExtractValue(DimensionHandle d, int64* result) { 91 if (!InferenceContext::ValueKnown(d)) { 92 *result = -counter; 93 counter++; 94 } else { 95 int64 val = InferenceContext::Value(d); 96 if (val >= 0) { 97 *result = val; 98 } else { 99 // A shape inference function generated an invalid dimension handle. 100 // Use a symbolic dimension to encode this. 101 *result = -counter; 102 counter++; 103 } 104 } 105 } 106 107 // Merge the dimensions d1 and d2. Return the known shape if there is one, 108 // otherwise look for a symbolic shape. If there is no symbolic shape and no 109 // known shape, the shape if fully unknown so return -1. 
110 Status Merge(DimensionHandle d1, DimensionHandle d2, int64* result) { 111 const int64 dim1 = InferenceContext::Value(d1); 112 const int64 dim2 = InferenceContext::Value(d2); 113 114 if (dim1 >= 0 && dim2 >= 0) { 115 CHECK_EQ(dim1, dim2); 116 return RefineDim(dim1, result); 117 } else if (dim1 >= 0 && dim2 < 0) { 118 return RefineDim(dim1, result); 119 } else if (dim1 < 0 && dim2 >= 0) { 120 return RefineDim(dim2, result); 121 } else if (dim1 < -1) { 122 return RefineDim(dim1, result); 123 } else if (dim2 < -1) { 124 return RefineDim(dim2, result); 125 } else { 126 CHECK_EQ(dim1, dim2); 127 CHECK_EQ(-1, dim1); 128 return RefineDim(-1, result); 129 } 130 return Status::OK(); 131 } 132 133 private: 134 Status RefineDim(int64 dim, int64* result) { 135 if (*result >= 0) { 136 if (!(*result == dim || dim < 0)) { 137 return errors::InvalidArgument("Inconsistent dimensions detected"); 138 } 139 } else if (dim >= 0) { 140 *result = dim; 141 } else if (dim < *result) { 142 *result = dim; 143 } 144 return Status::OK(); 145 } 146 147 int64 counter = 2; 148 }; 149 150 // Traditional Disjoint-Set datastructure with path compression. 151 // (https://en.wikipedia.org/wiki/Disjoint-set_data_structure) 152 template <typename Handle> 153 class DisjointSet { 154 public: 155 DisjointSet(const Processor<Handle>& processor) : processor_(processor) {} 156 ~DisjointSet() { 157 for (auto rep : nodes_) { 158 delete rep.second; 159 } 160 } 161 162 Status Merge(Handle x, Handle y); 163 const typename HandleToObject<Handle>::Object GetMergedValue(Handle value); 164 165 private: 166 // All the handles that belong to the same set are part of the same tree, and 167 // utimately represented by the root of that tree. 168 struct Rep { 169 // Parent in the tree used to encode the set. 170 Rep* parent; 171 // Rank in the tree, used to figure out how to compress the path to the root 172 // of the tree. 173 int rank; 174 // The handle. 
175 typename HandleToObject<Handle>::Object value; 176 }; 177 178 // Create a new set for the value if none exists, or return its representative 179 // node otherwise. 180 Rep* Find(Handle value); 181 182 private: 183 Processor<Handle> processor_; 184 std::unordered_map<Handle, Rep*, HashHandle<Handle>, CompareHandle<Handle>> 185 nodes_; 186 }; 187 188 template <typename Handle> 189 const typename HandleToObject<Handle>::Object 190 DisjointSet<Handle>::GetMergedValue(Handle value) { 191 Rep* rep = Find(value); 192 if (!rep) { 193 // We don't know anything about this handle. 194 return HandleToObject<Handle>::Unknown(); 195 } 196 return rep->value; 197 } 198 199 template <typename Handle> 200 Status DisjointSet<Handle>::Merge(Handle x, Handle y) { 201 Rep* x_root = Find(x); 202 Rep* y_root = Find(y); 203 204 // x and y are already in the same set 205 if (x_root == y_root) { 206 return Status::OK(); 207 } 208 // x and y are not in same set, so we merge them 209 // Use the occasion to strengthen what we know about the handle by merging the 210 // information about the 2 subsets. 211 if (x_root->rank < y_root->rank) { 212 TF_RETURN_IF_ERROR(processor_.Merge(y, x, &y_root->value)); 213 x_root->parent = y_root; 214 } else if (x_root->rank > y_root->rank) { 215 TF_RETURN_IF_ERROR(processor_.Merge(x, y, &x_root->value)); 216 y_root->parent = x_root; 217 } else { 218 TF_RETURN_IF_ERROR(processor_.Merge(x, y, &x_root->value)); 219 // Arbitrarily make one root the new parent 220 y_root->parent = x_root; 221 x_root->rank = x_root->rank + 1; 222 } 223 return Status::OK(); 224 } 225 226 template <typename Handle> 227 typename DisjointSet<Handle>::Rep* DisjointSet<Handle>::Find(Handle value) { 228 auto it = nodes_.find(value); 229 if (it == nodes_.end()) { 230 // This is the first time we process this handle, create an entry for it. 
231 Rep* node = new Rep; 232 node->parent = node; 233 node->rank = 0; 234 processor_.ExtractValue(value, &node->value); 235 nodes_[value] = node; 236 return node; 237 } 238 // Return the representative for the set, which is the root of the tree. Apply 239 // path compression to speedup future queries. 240 Rep* node = it->second; 241 Rep* root = node->parent; 242 while (root != root->parent) { 243 root = root->parent; 244 } 245 while (node->parent != root) { 246 Rep* next = node->parent; 247 node->parent = root; 248 node = next; 249 } 250 return root; 251 } 252 253 bool IsQueue(const Node& node) { 254 StringPiece type(node.type_string()); 255 return type.ends_with("QueueV2"); 256 } 257 258 // Returns true if the node is an Enter op AND its input is a Queue. 259 bool IsEnterWithQueue(const Node& node) { 260 if (node.IsEnter()) { 261 const Node* in_node; 262 TF_CHECK_OK(node.input_node(0, &in_node)); 263 return IsQueue(*in_node); 264 } 265 return false; 266 } 267 268 bool HasAnyUnknownDimensions(const TensorShapeProto& proto) { 269 if (proto.unknown_rank()) { 270 return true; 271 } 272 for (const auto& dim : proto.dim()) { 273 if (dim.size() < 0) { 274 return true; 275 } 276 } 277 return false; 278 } 279 280 void VerboseLogUnknownDimensionSources( 281 const Graph& graph, 282 const std::map<string, std::vector<OpInfo::TensorProperties>>& 283 input_properties_map, 284 const std::map<string, std::vector<OpInfo::TensorProperties>>& 285 output_properties_map) { 286 if (!VLOG_IS_ON(2)) { 287 return; 288 } 289 290 VLOG(2) << "Nodes with known inputs, but with unknown output dimensions:"; 291 292 // Find all nodes in the graph for which we 293 // do not have any unknown dimensions in their inputs, but 294 // we have some unknown dimensions in their outputs. 
295 std::map<string, int> op_to_count; 296 for (const Node* const node : graph.nodes()) { 297 if (node->num_outputs() == 0) { 298 continue; 299 } 300 301 const auto& input_properties = input_properties_map.at(node->name()); 302 const auto& output_properties = output_properties_map.at(node->name()); 303 304 bool has_unknown_inputs = false; 305 for (int i = 0; i < node->num_inputs(); ++i) { 306 if (HasAnyUnknownDimensions(input_properties[i].shape())) { 307 has_unknown_inputs = true; 308 break; 309 } 310 } 311 312 if (has_unknown_inputs) { 313 continue; 314 } 315 316 for (int i = 0; i < node->num_outputs(); ++i) { 317 if (HasAnyUnknownDimensions(output_properties[i].shape())) { 318 string inputs = "input_shapes=["; 319 for (int i = 0; i < node->num_inputs(); ++i) { 320 inputs += 321 PartialTensorShape::DebugString(input_properties[i].shape()); 322 } 323 inputs += "]"; 324 325 string outputs = "output_shapes=["; 326 for (int i = 0; i < node->num_outputs(); ++i) { 327 outputs += 328 PartialTensorShape::DebugString(output_properties[i].shape()); 329 } 330 outputs += "]"; 331 332 VLOG(2) << "Node: " << node->name() << ", Op: " << node->def().op() 333 << ", " << inputs << ", " << outputs; 334 335 op_to_count[node->def().op()]++; 336 337 // don't log again for this node 338 break; 339 } 340 } 341 } 342 VLOG(2) << "Op types with known inputs, but with unknown output dimensions " 343 << "(format: <op_type> (<count>)):"; 344 for (const auto& p : op_to_count) { 345 VLOG(2) << p.first << " (" << p.second << ")"; 346 } 347 } 348 349 } // namespace 350 351 // Queue of nodes to process. Nodes can be enqueued in any order, but will be 352 // dequeued in (roughly) topological order. Propagating shapes following a 353 // topological ordering isn't required for correctness but helps speed things up 354 // since it avoids processing the same node multiple times as its inputs 355 // information is refined. 
356 class TopoQueue { 357 public: 358 void push(const Node* n) { queue_.insert(n); } 359 const Node* pop() { 360 CHECK(!empty()); 361 auto it = queue_.begin(); 362 const Node* n = *it; 363 queue_.erase(it); 364 return n; 365 } 366 367 bool empty() const { return queue_.empty(); } 368 std::size_t size() const { return queue_.size(); } 369 370 private: 371 // Graph nodes are created in (roughly) topological order. Therefore we can 372 // use their id to ensure they're sorted topologically. 373 struct CompareNodes { 374 bool operator()(const Node* lhs, const Node* rhs) const { 375 return lhs->id() < rhs->id(); 376 } 377 }; 378 std::set<const Node*, CompareNodes> queue_; 379 }; 380 381 // Merge and relax symbolic shapes. 382 // Each symbolic shape or dimension is represented by a handle. Unlike the TF 383 // shape refiner which creates new handles every time it processes an unknown 384 // shape/dimension, the symbolic shape refiner assigns a specific handle to each 385 // unknown shape/dimension of a given node. 
386 class SymbolicShapeRefiner { 387 public: 388 explicit SymbolicShapeRefiner(ShapeRefiner* shape_refiner) 389 : shape_refiner_(shape_refiner) {} 390 391 InferenceContext* GetContext(const Node* node) { 392 return shape_refiner_->GetContext(node); 393 } 394 Status UpdateNode(const Node* node, bool relax, bool* refined) { 395 return shape_refiner_->UpdateNode(node, relax, refined); 396 } 397 Status SetUnknownShape(const Node* node, int output_port) { 398 shape_inference::ShapeHandle shape = 399 GetUnknownOutputShape(node, output_port); 400 InferenceContext* ctx = GetContext(node); 401 if (ctx == nullptr) { 402 return errors::InvalidArgument("Missing context"); 403 } 404 ctx->set_output(output_port, shape); 405 return Status::OK(); 406 } 407 408 struct ShapeId { 409 const Node* node; 410 int port_id; 411 bool operator==(const ShapeId& other) const { 412 return node == other.node && port_id == other.port_id; 413 } 414 }; 415 struct HashShapeId { 416 std::size_t operator()(const ShapeId& shp) const { 417 return std::hash<const Node*>{}(shp.node) + shp.port_id; 418 } 419 }; 420 421 struct DimId { 422 const Node* node; 423 int port_id; 424 int dim_index; 425 bool operator==(const DimId& other) const { 426 return node == other.node && port_id == other.port_id && 427 dim_index == other.dim_index; 428 } 429 }; 430 431 struct HashDimId { 432 std::size_t operator()(const DimId& dim) const { 433 return std::hash<const Node*>{}(dim.node) + dim.port_id + dim.dim_index; 434 } 435 }; 436 437 // Compute the shape of the tensors outputed by node 'node' at output port 438 // 'port_index' as the intersection of shape1 and shape2. 
439 ShapeHandle OutputAsIntersection(const Node* node, int port_index, 440 ShapeHandle shape1, ShapeHandle shape2) { 441 if (shape1.SameHandle(shape2)) { 442 return shape1; 443 } 444 InferenceContext* ctx = shape_refiner_->GetContext(node); 445 ShapeHandle merged = shape1; 446 if (!ctx->RankKnown(shape2) && !ctx->RankKnown(shape1)) { 447 // Return either one since they're expected to represent the same value. 448 return shape1; 449 } else if (!ctx->RankKnown(shape2) && ctx->RankKnown(shape1)) { 450 return shape1; 451 } else if (ctx->RankKnown(shape2) && !ctx->RankKnown(shape1)) { 452 return shape2; 453 } else { 454 const int rank = ctx->Rank(shape1); 455 if (ctx->Rank(shape2) != rank) { 456 // We detected an inconsistency, return an unknown shape. This can 457 // happen in the fanout of a merge node since during the initial 458 // propagation we optimistically assume that all the inputs to the merge 459 // node have the same shape. 460 return GetUnknownOutputShape(node, port_index); 461 } 462 for (int d = 0; d < rank; ++d) { 463 if (!ctx->Dim(shape1, d).SameHandle(ctx->Dim(shape2, d))) { 464 if (ctx->Value(ctx->Dim(shape1, d)) != 465 ctx->Value(ctx->Dim(shape2, d))) { 466 DimensionHandle new_dim; 467 if (ctx->Value(ctx->Dim(shape1, d)) < 0) { 468 new_dim = ctx->Dim(shape2, d); 469 } else if (ctx->Value(ctx->Dim(shape2, d)) < 0) { 470 new_dim = ctx->Dim(shape1, d); 471 } else { 472 new_dim = GetUnknownOutputDim(node, port_index, d); 473 } 474 TF_CHECK_OK(ctx->ReplaceDim(merged, d, new_dim, &merged)); 475 } 476 } 477 } 478 } 479 return merged; 480 } 481 482 // Compute the shape of the tensors outputed by node 'node' at output port 483 // 'port_index' as the union of shape1 and shape2. 
484 ShapeHandle OutputAsUnion(const Node* node, int port_index, 485 ShapeHandle shape1, ShapeHandle shape2) { 486 if (shape1.SameHandle(shape2)) { 487 return shape1; 488 } 489 InferenceContext* ctx = shape_refiner_->GetContext(node); 490 ShapeHandle relaxed = shape1; 491 const int rank = ctx->Rank(shape1); 492 if (!ctx->RankKnown(shape2) || ctx->Rank(shape2) != rank) { 493 relaxed = GetUnknownOutputShape(node, port_index); 494 } else { 495 for (int d = 0; d < rank; ++d) { 496 if (!ctx->Dim(shape1, d).SameHandle(ctx->Dim(shape2, d))) { 497 int64 val1 = ctx->Value(ctx->Dim(shape1, d)); 498 int64 val2 = ctx->Value(ctx->Dim(shape2, d)); 499 if (val1 != val2 || (val1 < 0 && val2 < 0)) { 500 DimensionHandle new_dim = GetUnknownOutputDim(node, port_index, d); 501 TF_CHECK_OK(ctx->ReplaceDim(relaxed, d, new_dim, &relaxed)); 502 } 503 } 504 } 505 } 506 return relaxed; 507 } 508 509 bool EquivalentShapes(ShapeHandle s1, ShapeHandle s2) const { 510 if (s1.SameHandle(s2)) { 511 return true; 512 } 513 if (InferenceContext::Rank(s1) != InferenceContext::Rank(s2)) { 514 return false; 515 } 516 if (!InferenceContext::RankKnown(s1) && !InferenceContext::RankKnown(s2)) { 517 return true; 518 } 519 const int rank = InferenceContext::Rank(s1); 520 for (int i = 0; i < rank; ++i) { 521 if (!InferenceContext::DimKnownRank(s1, i).SameHandle( 522 InferenceContext::DimKnownRank(s2, i))) { 523 int64 val1 = 524 InferenceContext::Value(InferenceContext::DimKnownRank(s1, i)); 525 int64 val2 = 526 InferenceContext::Value(InferenceContext::DimKnownRank(s2, i)); 527 if (val1 >= 0 && val2 >= 0 && val1 == val2) { 528 continue; 529 } 530 return false; 531 } 532 } 533 return true; 534 } 535 536 bool EquivalentShapesAndTypes(const std::vector<ShapeAndType>& st1, 537 const std::vector<ShapeAndType>& st2) const { 538 if (st1.size() != st2.size()) { 539 return false; 540 } 541 for (int i = 0; i < st1.size(); ++i) { 542 const ShapeAndType& s1 = st1[i]; 543 const ShapeAndType& s2 = st2[i]; 544 if (s1.dtype 
!= s2.dtype) { 545 return false; 546 } 547 if (!EquivalentShapes(s1.shape, s2.shape)) { 548 return false; 549 } 550 } 551 return true; 552 } 553 554 private: 555 // Return the one ShapeHandle used to denote a fully unknown shape for a node 556 // output. 557 ShapeHandle GetUnknownOutputShape(const Node* node, int index) { 558 ShapeId id{node, index}; 559 auto it = unknown_shapes_.find(id); 560 if (it != unknown_shapes_.end()) { 561 return it->second; 562 } 563 InferenceContext* c = shape_refiner_->GetContext(node); 564 ShapeHandle shp = c->UnknownShape(); 565 unknown_shapes_[id] = shp; 566 return shp; 567 } 568 // Return the one ShapeHandle used to denote a fully unknown dimension for a 569 // node output. 570 DimensionHandle GetUnknownOutputDim(const Node* node, int index, int dim_id) { 571 DimId id{node, index, dim_id}; 572 auto it = unknown_dims_.find(id); 573 if (it != unknown_dims_.end()) { 574 return it->second; 575 } 576 InferenceContext* c = shape_refiner_->GetContext(node); 577 DimensionHandle dim = c->UnknownDim(); 578 unknown_dims_[id] = dim; 579 return dim; 580 } 581 582 ShapeRefiner* shape_refiner_; 583 584 std::unordered_map<ShapeId, ShapeHandle, HashShapeId> unknown_shapes_; 585 std::unordered_map<DimId, DimensionHandle, HashDimId> unknown_dims_; 586 }; 587 588 // Keep track of shapes and dimensions in a graph. 589 // In particular, use disjoint sets to track equivalence between shapes and 590 // dims, and consolidate the information globally. 
591 class SymbolicShapeManager { 592 public: 593 SymbolicShapeManager() : shapes_(shape_processor_), dims_(dim_processor_) {} 594 595 Status Merge(ShapeHandle s1, ShapeHandle s2) { 596 if (!s1.IsSet() || !s2.IsSet()) { 597 return Status::OK(); 598 } 599 TF_RETURN_IF_ERROR(shapes_.Merge(s1, s2)); 600 if (InferenceContext::Rank(s1) > 0 && InferenceContext::Rank(s2) > 0) { 601 CHECK_EQ(InferenceContext::Rank(s1), InferenceContext::Rank(s2)); 602 for (int i = 0; i < InferenceContext::Rank(s1); ++i) { 603 TF_RETURN_IF_ERROR(dims_.Merge(InferenceContext::DimKnownRank(s1, i), 604 InferenceContext::DimKnownRank(s2, i))); 605 } 606 } 607 return Status::OK(); 608 } 609 Status Merge(DimensionHandle d1, DimensionHandle d2) { 610 if (!d1.IsSet() || !d2.IsSet()) { 611 return Status::OK(); 612 } 613 return dims_.Merge(d1, d2); 614 } 615 616 void AsTensorProperties(const ShapeHandle& shape, const DataType& type, 617 OpInfo::TensorProperties* properties) { 618 properties->set_dtype(type); 619 ShapeHandle actual_shape = shapes_.GetMergedValue(shape); 620 if (!InferenceContext::RankKnown(actual_shape)) { 621 properties->mutable_shape()->set_unknown_rank(true); 622 } else { 623 for (int j = 0; j < InferenceContext::Rank(actual_shape); ++j) { 624 shape_inference::DimensionHandle dim = 625 InferenceContext::DimKnownRank(actual_shape, j); 626 int64 d = dims_.GetMergedValue(dim); 627 properties->mutable_shape()->add_dim()->set_size(d); 628 } 629 } 630 } 631 632 private: 633 Processor<ShapeHandle> shape_processor_; 634 DisjointSet<shape_inference::ShapeHandle> shapes_; 635 Processor<DimensionHandle> dim_processor_; 636 DisjointSet<shape_inference::DimensionHandle> dims_; 637 }; 638 639 Status GraphProperties::MergeEnqueueShapesAndTypes( 640 SymbolicShapeRefiner* shape_refiner, const Node* qnode, 641 const std::vector<ShapeAndType>& shapes_and_types, 642 std::vector<ShapeAndType>* queue_shapes_and_types) { 643 if (shapes_and_types.size() != queue_shapes_and_types->size()) { 644 return 
errors::InvalidArgument( 645 "Enqueue nodes mixed number of tensors: ", shapes_and_types.size(), 646 " vs ", queue_shapes_and_types->size()); 647 } 648 for (size_t i = 0; i < shapes_and_types.size(); ++i) { 649 const ShapeAndType& a = shapes_and_types[i]; 650 ShapeAndType& b = (*queue_shapes_and_types)[i]; 651 if (a.dtype != b.dtype) { 652 return errors::InvalidArgument("Enqueue nodes mixed dtypes for tensor ", 653 i, ": ", DataTypeString(a.dtype), " vs ", 654 DataTypeString(b.dtype)); 655 } 656 657 b.shape = shape_refiner->OutputAsIntersection(qnode, i, a.shape, b.shape); 658 } 659 return Status::OK(); 660 } 661 662 Status GraphProperties::RelaxEnqueueShapesAndMergeTypes( 663 SymbolicShapeRefiner* shape_refiner, const Node* qnode, 664 const std::vector<ShapeAndType>& shapes_and_types, 665 std::vector<ShapeAndType>* queue_shapes_and_types) { 666 if (shapes_and_types.size() != queue_shapes_and_types->size()) { 667 return errors::InvalidArgument( 668 "Enqueue nodes mixed number of tensors: ", shapes_and_types.size(), 669 " vs ", queue_shapes_and_types->size()); 670 } 671 for (size_t i = 0; i < shapes_and_types.size(); ++i) { 672 const ShapeAndType& a = shapes_and_types[i]; 673 ShapeAndType& b = (*queue_shapes_and_types)[i]; 674 if (a.dtype != b.dtype) { 675 return errors::InvalidArgument("Enqueue nodes mixed dtypes for tensor ", 676 i, ": ", DataTypeString(a.dtype), " vs ", 677 DataTypeString(b.dtype)); 678 } 679 680 b.shape = shape_refiner->OutputAsUnion(qnode, i, a.shape, b.shape); 681 } 682 return Status::OK(); 683 } 684 685 // If a Merge node has a NextIteration node as an input then that input will 686 // try to forward an UnknownShape at graph construction time. However, the 687 // Merge shape function will always propagate an UnknownShape if any of its 688 // inputs are UnknownShapes. So we need to ignore the input from NextIteration 689 // nodes to propagate any known shape from the Merge node. 
// Recomputes output 0 of the Merge node 'node' from the shapes of its data
// inputs, and enqueues the node in 'new_shapes' if that output changed.
Status GraphProperties::UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
                                        const Node* node, bool relax,
                                        TopoQueue* new_shapes) {
  InferenceContext* c = shape_refiner->GetContext(node);
  CHECK_NE(c, nullptr);

  // Output 1 is constrained to rank 0 (presumably the scalar value_index
  // output of the Merge op; confirm against the op definition).
  ShapeHandle out1;
  TF_RETURN_IF_ERROR(c->WithRank(c->output(1), 0, &out1));
  c->set_output(1, out1);

  ShapeHandle out;
  bool out_initialized = false;
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) {
      continue;
    }
    // Skip back edges during the initial propagation phase. This is equivalent
    // to assuming that all the inputs to the merge nodes are fed by the same
    // shape, and will be corrected as needed in the relaxation phase.
    if (!relax && e->src()->IsNextIteration()) {
      continue;
    }

    InferenceContext* in = shape_refiner->GetContext(e->src());
    ShapeHandle input = in->output(e->src_output());
    if (relax) {
      c->RelaxInput(e->dst_input(), input);
    } else {
      c->MergeInput(e->dst_input(), input);
    }
    // The first input processed seeds the output; the following ones are
    // combined with it.
    if (!out_initialized) {
      out_initialized = true;
      out = input;
      continue;
    }
    if (relax) {
      out = shape_refiner->OutputAsUnion(node, 0, input, out);
    } else {
      out = shape_refiner->OutputAsIntersection(node, 0, input, out);
    }
  }

  // Only report a new shape when the output actually changed.
  if (!shape_refiner->EquivalentShapes(out, c->output(0))) {
    c->set_output(0, out);
    new_shapes->push(node);
  }

  return Status::OK();
}

// If 'node' is listed in 'fed_ports', sets each of its fed output ports to an
// unknown shape and enqueues the node in 'new_shapes'. Returns the combined
// status of the individual updates.
Status GraphProperties::OverwriteFedPorts(
    SymbolicShapeRefiner* shape_refiner,
    const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
    const Node* node, TopoQueue* new_shapes) const {
  auto it = fed_ports.find(node->name());
  Status status;
  if (it != fed_ports.end()) {
    // It is possible to feed node output ports with tensors of any shape: as a
    // result, the shape of a fed port is completely unknown.
    for (const int output_port : it->second) {
      status.Update(shape_refiner->SetUnknownShape(node, output_port));
    }
    new_shapes->push(node);
  }
  return status;
}

// Manually propagate the input shape for Enter nodes and update any Merge node
// outputs.
Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,
                                    const Node* node, bool relax,
                                    TopoQueue* new_shapes) {
  auto enter_ctx = shape_refiner->GetContext(node);
  CHECK_NE(enter_ctx, nullptr);

  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) {
      continue;
    }
    InferenceContext* in = shape_refiner->GetContext(e->src());
    ShapeHandle input = in->output(e->src_output());
    // Forward the input shape to output 0 unless it is already there, and
    // enqueue the node so that its fanout gets updated.
    if (!enter_ctx->output(0).SameHandle(input)) {
      if (relax) {
        enter_ctx->RelaxInput(0, input);
      } else {
        enter_ctx->MergeInput(0, input);
      }
      enter_ctx->set_output(0, input);
      new_shapes->push(node);
    }
  }
  return Status::OK();
}

// Updates the shapes of node 'n': Enter and Merge nodes are handled manually,
// the other nodes go through the regular shape refiner. Fed ports are then
// overwritten with unknown shapes. The node is enqueued in 'new_shapes'
// whenever its shapes changed.
Status GraphProperties::UpdateShapes(
    SymbolicShapeRefiner* shape_refiner, bool relax,
    const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
    const Node* n, TopoQueue* new_shapes) const {
  if (n->IsEnter()) {
    // The Enter shape function always forwards an UnknownShape, so do the right
    // thing here.
    TF_RETURN_IF_ERROR(UpdateEnter(shape_refiner, n, relax, new_shapes));
  } else if (n->IsMerge()) {
    // Properly handle merge nodes.
    TF_RETURN_IF_ERROR(UpdateMergeNode(shape_refiner, n, relax, new_shapes));
  } else {
    // Rely on regular TF shape refinement for all the other nodes.
    bool updated = false;
    TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, relax, &updated));
    if (updated) {
      // We want to avoid propagating through loops on the merge pass because
      // the shapes are not guaranteed to converge.
      if (relax || !n->IsNextIteration()) {
        new_shapes->push(n);
      }
    }
  }
  // Nodes can be fed with any shape. The TensorFlow shape inference code can't
  // handle this properly, so overwrite its behavior here.
  return OverwriteFedPorts(shape_refiner, fed_ports, n, new_shapes);
}

// Propagates the shapes in the transitive fan-out of <new_shapes>.
// Returns an Internal error if 'new_shapes' is still not empty once the
// iteration limits are reached.
Status GraphProperties::PropagateShapes(
    SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes,
    const std::unordered_map<const Node*, std::unordered_set<const Node*>>&
        resources,
    const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
    int num_loops) const {
  // Limit the number of iterations to prevent infinite loops in the presence of
  // incorrect shape functions. The algorithm should converge in at most
  // num_nested_loops^2 * max_rank. We approximate max_rank with the constant 4.
  // The same applies to resources.
  VLOG(1) << "Propagating (relax=" << relax << ") " << new_shapes->size()
          << " new shapes through " << num_loops << " loops and "
          << resources.size() << " resources" << std::endl;

  const int64 max_loop_length = item_.graph.node_size();
  const int64 max_rank = 4;
  const int64 max_loop_iterations =
      max_rank * max_loop_length * std::max<int64>(1, num_loops * num_loops);
  const int64 num_queues = resources.size();
  const int64 max_resource_iterations = num_queues * num_queues * max_rank;

  int64 num_resource_iterations = 0;
  do {
    // Inner loop: drain the queue, updating the data fanout of each node that
    // was dequeued.
    int64 num_loop_iterations = 0;
    while (!new_shapes->empty() &&
           num_loop_iterations++ < max_loop_iterations) {
      const Node* n = new_shapes->pop();
      for (const Edge* e : n->out_edges()) {
        if (!e->IsControlEdge()) {
          const Node* fanout = e->dst();
          TF_RETURN_IF_ERROR(UpdateShapes(shape_refiner, relax, fed_ports,
                                          fanout, new_shapes));
        }
      }
    }

    for (const auto& resource : resources) {
      // Resources need special handling: since the enqueue nodes are in the
      // fanout of the queues, we need to manually propagate the shapes from
      // enqueue node to the corresponding queue.
      TF_RETURN_IF_ERROR(UpdateResource(resource.first, resource.second,
                                        shape_refiner, relax, new_shapes));
    }
    // The resource updates may have enqueued nodes again: keep going until the
    // queue stays empty or the iteration limit is hit.
  } while (!new_shapes->empty() &&
           num_resource_iterations++ < max_resource_iterations);

  if (!new_shapes->empty()) {
    return errors::Internal("Shape inference failed to converge");
  }

  return Status::OK();
}

// Recomputes the shapes and types attached to output 0 of the queue 'qnode'
// from the inputs of the enqueue nodes found in 'queue_inputs', and enqueues
// 'qnode' in 'new_shapes' if they changed. Does nothing if 'qnode' is neither a
// queue nor an Enter node fed by a queue, or if it has no inference context.
Status GraphProperties::UpdateResource(
    const Node* qnode, const std::unordered_set<const Node*>& queue_inputs,
    SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes) {
  // Proceed only if qnode is a queue or an Enter with queue input.
  if (!IsQueue(*qnode) && !IsEnterWithQueue(*qnode)) {
    return Status::OK();
  }
  auto qctx = shape_refiner->GetContext(qnode);
  if (!qctx) {
    return Status::OK();
  }
  auto* queue_handle_data = qctx->output_handle_shapes_and_types(0);

  // Merge all inputs into the enqueue node, regardless of which phase we
  // are in.
  std::vector<ShapeAndType> queue_shapes_and_types;
  if (queue_handle_data) {
    queue_shapes_and_types = *queue_handle_data;
  }
  for (const auto& node : queue_inputs) {
    auto ctx = shape_refiner->GetContext(node);
    if (!ctx) {
      continue;
    }
    // TODO(bsteiner): handle EnqueueMany as well.
    if (node->type_string().find("Enqueue") != std::string::npos &&
        node->type_string().find("EnqueueMany") == std::string::npos) {
      // Input 0 is skipped (presumably the queue handle itself; confirm
      // against the enqueue op definition). The remaining inputs are the
      // enqueued tensors.
      std::vector<ShapeAndType> shapes_and_types;
      for (int i = 1; i < ctx->num_inputs(); ++i) {
        shapes_and_types.push_back({ctx->input(i), node->input_type(i)});
      }
      if (queue_shapes_and_types.empty()) {
        queue_shapes_and_types = shapes_and_types;
      } else {
        if (relax) {
          TF_RETURN_IF_ERROR(RelaxEnqueueShapesAndMergeTypes(
              shape_refiner, qnode, shapes_and_types, &queue_shapes_and_types));
        } else {
          TF_RETURN_IF_ERROR(MergeEnqueueShapesAndTypes(
              shape_refiner, qnode, shapes_and_types, &queue_shapes_and_types));
        }
      }
    }
  }

  // Only update the queue and report a new shape if something changed.
  if (queue_handle_data == nullptr ||
      !shape_refiner->EquivalentShapesAndTypes(*queue_handle_data,
                                               queue_shapes_and_types)) {
    qctx->set_output_handle_shapes_and_types(0, queue_shapes_and_types);

    new_shapes->push(qnode);
  }

  return Status::OK();
}

// Infers the shapes of the tensors of the graph without running it, and fills
// input_properties_ and output_properties_ with the result. Unless
// 'assume_valid_feeds' is true, the output ports listed in the feeds of the
// item are given an unknown shape.
Status GraphProperties::InferStatically(bool assume_valid_feeds) {
  Graph graph(OpRegistry::Global());
  FunctionLibraryDefinition function_library(graph.op_registry(),
                                             item_.graph.library());
  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
  shape_refiner.set_require_shape_inference_fns(false);
  shape_refiner.set_disable_constant_propagation(true);
  shape_refiner.set_function_library_for_shape_inference(&function_library);
  ImportGraphDefOptions options;
  // Graph optimization happens at the late stage of graph execution,
  // when colocation constraints are already validated previously and
  // the device placement of nodes has also completed, so there
  // is no need to validate colocation constraints again.
  options.validate_colocation_constraints = false;
  Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner);
  TF_RETURN_IF_ERROR(s);

  // Map each fed node name to the set of its output ports that are fed.
  std::unordered_map<string, std::unordered_set<int>> fed_ports;
  if (!assume_valid_feeds) {
    for (const auto& feed : item_.feed) {
      int port_index = 0;
      string node_name = ParseNodeName(feed.first, &port_index);
      fed_ports[node_name].insert(port_index);
    }
  }

  // List the resources and the nodes using them. Also collect the Enter and
  // Merge nodes.
  std::unordered_map<const Node*, std::unordered_set<const Node*>> resources;
  std::unordered_set<const Node*> enter_nodes;
  std::unordered_set<const Node*> merge_nodes;
  std::unordered_set<const Node*> fed_nodes;
  // Counted as the number of NextIteration nodes.
  int num_loops = 0;
  for (const Node* const node : graph.nodes()) {
    for (int i = 0; i < node->num_inputs(); ++i) {
      if (node->input_type(i) == DataType::DT_RESOURCE) {
        const Node* resource;
        TF_CHECK_OK(node->input_node(i, &resource));
        resources[resource].insert(node);
      }
    }
    if (node->IsEnter()) {
      enter_nodes.insert(node);
    } else if (node->IsMerge()) {
      merge_nodes.insert(node);
    } else if (node->IsNextIteration()) {
      ++num_loops;
    }
    if (fed_ports.find(node->name()) != fed_ports.end()) {
      fed_nodes.insert(node);
    }
  }

  SymbolicShapeRefiner refiner(&shape_refiner);

  // We propagate shapes through the graph in two phases. In the first phase, we
  // exclusively merge shapes but we do not propagate shapes through the
  // backedge of loops (i.e. the NextIteration node). Then on the second phase,
  // we exclusively relax shapes and propagate shapes through loops until
  // reaching fixed point.
  for (int relax = 0; relax < 2; relax++) {
    TopoQueue new_shapes;
    // Force the propagation of shapes of Enter nodes manually (the Enter shape
    // function always forwards an UnknownShape).
    for (const Node* node : enter_nodes) {
      TF_RETURN_IF_ERROR(
          UpdateShapes(&refiner, relax, fed_ports, node, &new_shapes));
    }
    // Seed the propagation of shapes through merge nodes.
    for (const Node* node : merge_nodes) {
      TF_RETURN_IF_ERROR(
          UpdateShapes(&refiner, relax, fed_ports, node, &new_shapes));
    }
    // Also seed the propagation of shapes in the fanout of fed nodes.
    for (const Node* node : fed_nodes) {
      TF_RETURN_IF_ERROR(
          OverwriteFedPorts(&refiner, fed_ports, node, &new_shapes));
    }
    // Propagate shapes normally.
    TF_RETURN_IF_ERROR(PropagateShapes(&refiner, relax, &new_shapes, resources,
                                       fed_ports, num_loops));
  }

  // Track shapes globally across the graph.
  SymbolicShapeManager shape_manager;
  bool found_error = false;
  for (const Node* const node : graph.nodes()) {
    auto node_ctx = shape_refiner.GetContext(node);
    if (!node_ctx) {
      continue;
    }
    // Skip any information that comes from fed nodes.
    if (fed_ports.find(node->name()) != fed_ports.end()) {
      continue;
    }
    for (const auto& merged_shapes : node_ctx->MergedShapes()) {
      if (!shape_manager.Merge(merged_shapes.first, merged_shapes.second)
               .ok()) {
        found_error = true;
        break;
      }
    }
    for (const auto& merged_dims : node_ctx->MergedDims()) {
      if (!shape_manager.Merge(merged_dims.first, merged_dims.second).ok()) {
        found_error = true;
        break;
      }
    }
    if (found_error) {
      // The shapes aren't consistent, we can't infer safely: discard all the
      // information discovered so far.
      shape_manager = SymbolicShapeManager();
      break;
    }
  }

  for (const Node* const node : graph.nodes()) {
    VLOG(3) << "Filling in graph properties for node: " << node->name();
    // Nodes without an inference context get no entry in the property maps.
    auto ctx = shape_refiner.GetContext(node);
    if (!ctx) {
      continue;
    }

    // Fill input properties.
    {
      CHECK_EQ(ctx->num_inputs(), node->num_inputs());
      auto& input_properties = input_properties_[node->name()];

      // Should always be empty, node names in graph are supposed to be unique.
      CHECK_EQ(input_properties.size(), 0);

      input_properties.resize(ctx->num_inputs());
      for (int i = 0; i < ctx->num_inputs(); ++i) {
        shape_manager.AsTensorProperties(ctx->input(i), node->input_type(i),
                                         &input_properties[i]);
      }
      // Attach the value of the inputs that are fed by a constant node.
      for (const auto& edge : node->in_edges()) {
        if (edge->IsControlEdge()) {
          continue;
        }
        if (!edge->src()->IsConstant()) {
          continue;
        }
        const int input_id = edge->dst_input();
        if (input_id >= input_properties.size()) {
          continue;
        }
        // Note: this 'node' is the NodeDef of the constant, and shadows the
        // graph node of the enclosing loop.
        const NodeDef& node = edge->src()->def();
        const TensorProto& raw_val = node.attr().at("value").tensor();
        *input_properties[input_id].mutable_value() = raw_val;
      }
    }

    // Fill output properties.
    {
      CHECK_EQ(ctx->num_outputs(), node->num_outputs());
      auto& output_properties = output_properties_[node->name()];

      // Should always be empty, node names in graph are supposed to be unique.
      CHECK_EQ(output_properties.size(), 0);

      output_properties.resize(ctx->num_outputs());
      for (int i = 0; i < ctx->num_outputs(); ++i) {
        shape_manager.AsTensorProperties(ctx->output(i), node->output_type(i),
                                         &output_properties[i]);
      }
    }
  }

  // Help trace the unknown dimensions to their origins.
  VerboseLogUnknownDimensionSources(graph, input_properties_,
                                    output_properties_);

  return Status::OK();
}

Status GraphProperties::InferDynamically(Cluster* cluster) {
  TF_RETURN_IF_ERROR(cluster->Initialize(item_));

  // Runs the model once to collect the shapes in the cost model.
1102 RunMetadata metadata; 1103 TF_RETURN_IF_ERROR( 1104 cluster->Run(item_.graph, item_.feed, item_.fetch, &metadata)); 1105 1106 return InferFromCostGraph(metadata.cost_graph()); 1107 } 1108 1109 Status GraphProperties::AnnotateOutputShapes(GraphDef* output_graph_def) const { 1110 *output_graph_def = item_.graph; 1111 for (int i = 0; i < output_graph_def->node_size(); i++) { 1112 auto node = output_graph_def->mutable_node(i); 1113 AttrValue attr_output_shape; 1114 auto tensor_properties = GetOutputProperties(node->name()); 1115 for (const auto& tensor_property : tensor_properties) { 1116 *attr_output_shape.mutable_list()->add_shape() = tensor_property.shape(); 1117 } 1118 (*node->mutable_attr())["_output_shapes"] = attr_output_shape; 1119 } 1120 return Status::OK(); 1121 } 1122 1123 Status GraphProperties::InferFromCostGraph(const CostGraphDef& cost_graph) { 1124 if (cost_graph.node_size() == 0) { 1125 LOG(WARNING) << "cost_graph is empty: nothing can be inferred!"; 1126 } 1127 std::unordered_map<string, const CostGraphDef::Node*> name_to_cost; 1128 std::unordered_map<string, const NodeDef*> name_to_node; // Empty 1129 for (auto& node : cost_graph.node()) { 1130 name_to_cost[node.name()] = &node; 1131 1132 std::vector<OpInfo::TensorProperties> output_properties; 1133 for (const auto& out : node.output_info()) { 1134 OpInfo::TensorProperties properties; 1135 properties.set_dtype(out.dtype()); 1136 *properties.mutable_shape() = out.shape(); 1137 output_properties.push_back(properties); 1138 } 1139 output_properties_[node.name()] = output_properties; 1140 } 1141 1142 for (const auto& node : item_.graph.node()) { 1143 // Skip the nodes that are not in the cost graph: these are nodes that 1144 // aren't run, because they aren't in the intersection of transitive fan-in 1145 // of a fetch node and the transitive fan-out of an input, or nodes that 1146 // were optimized away by the optimizer. 
1147 auto it = name_to_cost.find(node.name()); 1148 if (it == name_to_cost.end()) { 1149 continue; 1150 } 1151 std::vector<OpInfo::TensorProperties> inputs = 1152 FindInputFeatures(node, name_to_cost, name_to_node); 1153 1154 input_properties_[node.name()] = inputs; 1155 } 1156 return Status::OK(); 1157 } 1158 1159 bool GraphProperties::HasInputProperties(const string& name) const { 1160 return input_properties_.find(name) != input_properties_.end(); 1161 } 1162 1163 bool GraphProperties::HasOutputProperties(const string& name) const { 1164 return output_properties_.find(name) != output_properties_.end(); 1165 } 1166 1167 const std::vector<OpInfo::TensorProperties>& 1168 GraphProperties::GetInputProperties(const string& node_name) const { 1169 auto it = input_properties_.find(node_name); 1170 if (it != input_properties_.end()) { 1171 return it->second; 1172 } 1173 return missing_properties_; 1174 } 1175 1176 const std::vector<OpInfo::TensorProperties>& 1177 GraphProperties::GetOutputProperties(const string& node_name) const { 1178 auto it = output_properties_.find(node_name); 1179 if (it != output_properties_.end()) { 1180 return it->second; 1181 } 1182 return missing_properties_; 1183 } 1184 1185 } // end namespace grappler 1186 } // end namespace tensorflow 1187