/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/shape_refiner.h"

#include <deque>
#include <memory>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeAndType;
using shape_inference::ShapeHandle;

ShapeRefiner::ShapeRefiner(int graph_def_version,
                           const OpRegistryInterface* ops)
    : graph_def_version_(graph_def_version),
      ops_registry_(ops),
      graph_runner_(Env::Default()) {}

ShapeRefiner::ShapeRefiner(const VersionDef& versions,
                           const OpRegistryInterface* ops)
    : ShapeRefiner(versions.producer(), ops) {}

ShapeRefiner::~ShapeRefiner() {
  // The lifetimes of the tensors are bound to the GraphRunner, so the
  // tensors must be deleted before it is destroyed.
  const_tensor_map_.clear();
}

namespace {

constexpr char kArgOp[] = "_Arg";
constexpr char kRetvalOp[] = "_Retval";

// Runs shape inference for the given node using the given ShapeRefiner.
// The node must be a sub-node of a function node and the outer_context is
// the inference context of that function node in the outer graph.
Status InferShapesForFunctionSubNode(const Node* node, ShapeRefiner* refiner,
                                     InferenceContext* outer_context) {
  TF_RETURN_IF_ERROR(refiner->AddNode(node));
  InferenceContext* node_context = CHECK_NOTNULL(refiner->GetContext(node));

  if (StringPiece(node->type_string()) == kArgOp) {
    // Handle special node: function input.
    // Shapes for these nodes are provided in the outer inference
    // context.

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));

    if (index < 0 || outer_context->num_inputs() <= index) {
      return errors::Internal(
          "Function instantiation included invalid input index: ", index,
          " not in [0, ", outer_context->num_inputs(), ").");
    }

    node_context->set_output(0, outer_context->input(index));

    auto* resource = outer_context->input_handle_shapes_and_types(index);
    if (resource) {
      node_context->set_output_handle_shapes_and_types(0, *resource);
    }
  } else if (StringPiece(node->type_string()) == kRetvalOp) {
    // Handle special node: function output.
    // Shapes inferred for these nodes go into the outer inference
    // context.

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));

    if (index < 0 || outer_context->num_outputs() <= index) {
      return errors::Internal(
          "Function instantiation included invalid output index: ", index,
          " not in [0, ", outer_context->num_outputs(), ").");
    }

    // outer_context outlives node_context, so we need to create a new
    // shape handle owned by outer_context instead.
    ShapeHandle handle;
    TensorShapeProto proto;
    node_context->ShapeHandleToProto(node_context->input(0), &proto);
    TF_RETURN_IF_ERROR(outer_context->MakeShapeFromShapeProto(proto, &handle));
    outer_context->set_output(index, handle);

    auto* resource = node_context->input_handle_shapes_and_types(0);
    if (resource) {
      outer_context->set_output_handle_shapes_and_types(index, *resource);
    }
  }

  return Status::OK();
}

}  // namespace

// TODO(cwhipkey): When an inference context inside function has
// requested_input_tensor(i) or requested_input_tensor_as_partial_shape(i)
// set when input(i) is an _Arg op, then this request should propagate to
// context, and vice versa.
//
// NOTE: Recursive user-defined functions are not supported.
// Maybe we won't support recursive functions at all in TF, because of
// other maintainability issues.
Status ShapeRefiner::InferShapesForFunction(
    const tensorflow::FunctionDef* function_def, bool keep_nested_shapes,
    ExtendedInferenceContext* outer_context) {
  const Graph* graph;
  auto it = functions_.find(function_def);
  if (it != functions_.end()) {
    graph = it->second.get();
  } else {
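    // First time inferring shapes for this function: instantiate it and
    // cache the resulting graph so later calls can reuse it.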
    InstantiationResult result;
    TF_RETURN_IF_ERROR(InstantiateFunction(
        *function_def, outer_context->get_context()->attrs(),
        [this](const string& op, const OpDef** sig) {
          return this->function_library_->LookUpOpDef(op, sig);
        },
        &result));

    Graph* new_graph = new Graph(function_library_);
    GraphConstructorOptions options;
    options.allow_internal_ops = true;
    TF_RETURN_IF_ERROR(
        ConvertNodeDefsToGraph(options, result.nodes, new_graph));
    functions_[function_def].reset(new_graph);
    graph = new_graph;
  }

  std::unordered_set<const Node*> function_nodes;
  Status inference_status = Status::OK();
  {
    auto node_shape_inference_lambda = [this, &outer_context, &function_nodes,
                                        &inference_status](const Node* node) {
      if (!inference_status.ok()) return;
      inference_status = InferShapesForFunctionSubNode(
          node, this, outer_context->get_context());
      function_nodes.insert(node);
    };

    // Calls the inference lambda for each node only after all of its
    // predecessors have been visited, ensuring that nodes are added to the
    // ShapeRefiner in topological order.
    ReverseDFS(*graph, {}, node_shape_inference_lambda);
  }

  if (keep_nested_shapes && inference_status.ok()) {
    // Fill the nested inferences map.
    //
    // The materialized function graph has extra nodes for arguments and
    // return values that are not explicitly listed in the FunctionDef. We
    // filter out these special nodes here so that we do not expose
    // implementation details, and keep only inferences for the nodes listed
    // in the FunctionDef.
    std::unordered_map<string, const NodeDef*> user_defined_nodes;
    for (const auto& node_def : function_def->node_def()) {
      user_defined_nodes[node_def.name()] = &node_def;
    }

    std::unordered_map<string, std::unique_ptr<ExtendedInferenceContext>>
        nested_inferences;
    for (const Node* node : function_nodes) {
      const string& node_name = node->name();
      if (user_defined_nodes.find(node_name) != user_defined_nodes.end()) {
        nested_inferences[node_name] = std::move(node_to_context_[node]);
        node_to_context_.erase(node);
        // By default InferenceContext refers to a NodeDef from Graph.
        // Change it to the publicly accessible NodeDef of the function
        // definition.
        nested_inferences[node_name]->get_context()->node_def_ =
            user_defined_nodes[node_name];
      }
    }
    outer_context->set_nested_inferences(std::move(nested_inferences));
  } else {
    // Delete the contexts created for the function nodes to save memory.
    for (const Node* node : function_nodes) {
      node_to_context_.erase(node);
    }
  }

  return inference_status;
}

Status ShapeRefiner::AddNode(const Node* node) {
  // For each 'input' of this node, fetch the corresponding shape
  // from 'input's InferenceContext, and store it in a vector
  // indexed by 'node's input index.
  std::vector<Node*> input_nodes(node->num_inputs());
  std::vector<ShapeHandle> input_shapes(node->num_inputs());
  std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
      input_handle_shapes_and_types(node->num_inputs());
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) continue;

    Node* input = e->src();
    auto it = node_to_context_.find(input);
    if (it == node_to_context_.end()) {
      return errors::FailedPrecondition(
          "Input ", e->dst_input(), " ('", input->name(), "') for '",
          node->name(), "' was not previously added to ShapeRefiner.");
    }

    InferenceContext* c = it->second->get_context();
    DCHECK_GE(e->dst_input(), 0);
    input_nodes[e->dst_input()] = input;
    input_shapes[e->dst_input()] = c->output(e->src_output());

    // Only propagate handle data of edges which are carrying resource handles.
    if (e->src()->output_type(e->src_output()) == DT_RESOURCE) {
      const auto* in_v = c->output_handle_shapes_and_types(e->src_output());
      if (in_v != nullptr) {
        input_handle_shapes_and_types[e->dst_input()].reset(
            new std::vector<ShapeAndType>(*in_v));
      }
    }
  }

  // Get the shape function for this node
  const OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
  if (op_reg_data->shape_inference_fn == nullptr &&
      require_shape_inference_fns_) {
    return errors::InvalidArgument(
        "No shape inference function exists for op '", node->type_string(),
        "', did you forget to define it?");
  }

  // This needs to be filled in with real data in a second pass.
  std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
  std::vector<ShapeHandle> input_tensors_as_shapes;

  // Create the inference context for this node with the existing input shapes.
  std::unique_ptr<InferenceContext> c(
      new InferenceContext(graph_def_version_, &node->def(), node->op_def(),
                           input_shapes, input_tensors, input_tensors_as_shapes,
                           std::move(input_handle_shapes_and_types)));
  if (!c->construction_status().ok()) {
    return c->construction_status();
  }

  std::unique_ptr<ExtendedInferenceContext> ec(
      new ExtendedInferenceContext(std::move(c), node));

  // Run the shape inference function, and return if there was an error.
  TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, ec.get()));

  // Store the resulting context object in the map.
  node_to_context_[node].swap(ec);

  return Status::OK();
}

Status ShapeRefiner::SetShape(const Node* node, int output_port,
                              ShapeHandle shape) {
  auto c = GetContext(node);
  if (c == nullptr) {
    return errors::Internal("Could not find context for ", node->name());
  }

  if (output_port < 0 || output_port >= node->num_outputs()) {
    return errors::InvalidArgument(
        "output_port '", output_port, "' is out of range, ", "node '",
        node->name(), "' has ", node->num_outputs(), " outputs");
  }

  // Check compatibility, and merge the shapes.
  ShapeHandle existing_shape = c->output(output_port);
  TF_RETURN_IF_ERROR(c->Merge(existing_shape, shape, &shape));
  c->set_output(output_port, shape);

  // TODO(vrv): Do we need to propagate the new shape through all
  // consumers that change their outputs?  At the moment, python
  // does not do this, but this seems like a nice feature.

  // TODO(vrv): We might need to keep track of the fact that the
  // existing shape is invalidated, in case we need to propagate
  // this information to remote workers.
  return Status::OK();
}

Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) {
  auto it = node_to_context_.find(node);
  if (it == node_to_context_.end()) {
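    // This node has not been added yet, so adding it counts as a refinement.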
    *refined = true;
    return AddNode(node);
  }
  ExtendedInferenceContext* node_ext_context = it->second.get();
  InferenceContext* node_context = node_ext_context->get_context();

  // Give up if the context wasn't successfully built by the AddNode() method.
  TF_RETURN_IF_ERROR(node_context->construction_status());

  // Check if the shapes of the nodes in the fan-in of this node have changed,
  // and if they have, update the node's input shapes.
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) continue;

    int dst_input = e->dst_input();
    int src_output = e->src_output();

    Node* input = e->src();
    auto iter = node_to_context_.find(input);
    if (iter == node_to_context_.end()) {
      return errors::FailedPrecondition(
          "Input ", dst_input, " ('", input->name(), "') for '", node->name(),
          "' was not previously added to ShapeRefiner.");
    }

    InferenceContext* c = iter->second->get_context();
    DCHECK_GE(dst_input, 0);
    ShapeHandle existing_input = node_context->input(dst_input);
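    // When 'relax' is false, merge the incoming output shape into the
    // recorded input shape (refining it); when true, relax the recorded
    // input shape towards the incoming shape (generalizing it). In either
    // case, report refinement only if the defined shape actually changed.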
    if (!relax) {
      if (node_context->MergeInput(dst_input, c->output(src_output))) {
        if (!SameDefinedShape(node_context, node_context->input(dst_input),
                              existing_input)) {
          *refined = true;
        }
      }
    } else {
      if (node_context->RelaxInput(dst_input, c->output(src_output))) {
        if (!SameDefinedShape(node_context, node_context->input(dst_input),
                              existing_input)) {
          *refined = true;
        }
      }
    }

    // Also propagate handle shape and dtype of edges which are carrying
    // resource handles.
    if (e->src()->output_type(src_output) == DT_RESOURCE) {
      auto* outputs = c->output_handle_shapes_and_types(src_output);
      if (!outputs) continue;

      if (!relax &&
          node_context->MergeInputHandleShapesAndTypes(dst_input, *outputs)) {
        *refined = true;
      } else if (relax) {
        std::vector<ShapeAndType> existing_inputs;
        const std::vector<ShapeAndType>* inputs =
            node_context->input_handle_shapes_and_types(dst_input);
        if (inputs) {
          existing_inputs = *inputs;
        }
        if (node_context->RelaxInputHandleShapesAndMergeTypes(dst_input,
                                                              *outputs)) {
          if (IsUpdatedShapesOrTypes(
                  node_context, existing_inputs,
                  *node_context->input_handle_shapes_and_types(dst_input))) {
            *refined = true;
          }
        }
      }
    }
  }

  if (!*refined) {
    // No input shape has changed, we're done
    return Status::OK();
  }

  // Get and run the shape function for this node to update the shapes of the
  // outputs.
  const OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
  if (op_reg_data->shape_inference_fn == nullptr &&
      require_shape_inference_fns_) {
    return errors::InvalidArgument(
        "No shape inference function exists for op '", node->type_string(),
        "', did you forget to define it?");
  }

  if (!op_reg_data->shape_inference_fn) {
    // There is nothing more we can infer
    return Status::OK();
  }

  return RunShapeFn(node, op_reg_data, node_ext_context);
}

Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node,
                                                   int dst_idx, bool* evaluated,
                                                   Tensor* result) {
  *evaluated = false;

  const Edge* input_edge;
  TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));

  // Simple case: the source node is a constant
  const Node* src = input_edge->src();
  if (src->IsConstant()) {
    if (result->FromProto(src->def().attr().at("value").tensor())) {
      *evaluated = true;
      return Status::OK();
    }
  }

  if (disable_constant_propagation_) {
    return Status::OK();
  }

  bool is_constant_graph = false;
  Graph subgraph(ops_registry_);
  auto versions = subgraph.versions();
  versions.set_producer(graph_def_version_);
  subgraph.set_versions(versions);

  // We identify the possibly-constant subgraph to evaluate by
  // recursively iterating backwards through the inputs to 'node'
  // until we either 1) find an already existing input to our subgraph
  // (filled in `const_inputs`), 2) discover that our graph is not constant,
  // or 3) hit a root node.
  std::vector<std::pair<string, Tensor>> const_inputs;
  TF_RETURN_IF_ERROR(ExtractConstantSubgraph(
      input_edge->src(), &subgraph, &is_constant_graph, &const_inputs));
  if (!is_constant_graph) {
    return Status::OK();
  }
  const string output_tensor_name =
      strings::StrCat(input_edge->src()->name(), ":", input_edge->src_output());
  std::vector<Tensor> outputs;

  // NOTE: we should pass in a function library runtime if we want
  // to support constant-expression evaluation on functions.
  Status s = graph_runner_.Run(&subgraph, nullptr /* function_library */,
                               const_inputs, {output_tensor_name}, &outputs);

  // If any kernel in the constant graph is not registered in this
  // process, GraphRunner::Run may fail, in which case we cannot
  // propagate constants, so this is best-effort.
  if (s.ok()) {
    *result = outputs[0];
    *evaluated = true;

    // We memoize (small) constants evaluated so far, so
    // ExtractConstantSubgraph can avoid extracting the full
    // subgraph.  As we build up large graphs, this avoids
    // repeated computation of the early parts of a constant
    // graph.
    if (outputs[0].TotalBytes() <= kMaxTensorSize) {
      const_tensor_map_[output_tensor_name] = outputs[0];
    }
  }
  return Status::OK();
}

Status ShapeRefiner::TryToInferTensorOutputFromInputShapes(const Edge* edge,
                                                           Tensor* output,
                                                           bool* success) {
  *success = false;
  const Node* node = edge->src();
  auto it = node_to_context_.find(node);
  if (it == node_to_context_.end()) {
    return errors::FailedPrecondition("Node does not have context.");
  }
  InferenceContext* c = it->second->get_context();

  if (node->type_string() == "Shape") {
    // If input shapes to the shape op are fully defined,
    // we can infer the shape op's output tensor.
    bool fully_defined_inputs = c->FullyDefined(c->input(0));
    if (fully_defined_inputs) {
      int input_rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({input_rank}));
      if (node->output_type(0) == DT_INT32) {
        auto flat = t.flat<int>();
        for (int i = 0; i < input_rank; i++) {
          int64 dimension = c->Value(c->Dim(c->input(0), i));
          if (!FastBoundsCheck(dimension, std::numeric_limits<int32>::max())) {
            return errors::FailedPrecondition(
                "Shape has output type int32, but dimension exceeds maximum "
                "int32 value");
          }
          flat(i) = static_cast<int32>(dimension);
        }
      } else if (node->output_type(0) == DT_INT64) {
        auto flat = t.flat<int64>();
        for (int i = 0; i < input_rank; i++) {
          flat(i) = c->Value(c->Dim(c->input(0), i));
        }
      } else {
        return errors::FailedPrecondition(
            "Shape has output type that is not int32 or int64");
      }
      *output = t;
      *success = true;
    }
  } else if (node->type_string() == "Rank") {
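    // Rank produces a scalar equal to the input's rank, so its output can be
    // constant-folded whenever the input's rank is known.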
    bool rank_known = c->RankKnown(c->input(0));
    if (rank_known) {
      int32 input_rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({}));
      t.flat<int32>()(0) = input_rank;
      *output = t;
      *success = true;
    }
  } else if (node->type_string() == "Size") {
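    // Size produces a scalar holding the number of elements in the input, so
    // its output can be computed once the input shape is fully defined.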
    bool fully_defined_inputs = c->FullyDefined(c->input(0));
    if (fully_defined_inputs) {
      int32 rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({}));
      int64 size = 1;
      for (int i = 0; i < rank; i++) {
        size *= c->Value(c->Dim(c->input(0), i));
      }
      if (node->output_type(0) == DT_INT32) {
        if (!FastBoundsCheck(size, std::numeric_limits<int32>::max())) {
          return errors::FailedPrecondition(
              "Size has output type int32, but size exceeds maximum int32 "
              "value");
        }
        t.flat<int32>()(0) = static_cast<int32>(size);
      } else if (node->output_type(0) == DT_INT64) {
        t.flat<int64>()(0) = size;
      } else {
        return errors::FailedPrecondition(
            "Size has output type that is not int32 or int64");
      }
      *output = t;
      *success = true;
    }
  }
  return Status::OK();
}

Status ShapeRefiner::ExtractConstantSubgraph(
    Node* target_node, Graph* out_graph, bool* is_constant_graph,
    std::vector<std::pair<string, Tensor>>* const_inputs) {
  *is_constant_graph = false;
  std::unordered_set<string> const_inputs_added;

  if (target_node->op_def().is_stateful()) {
    return Status::OK();
  }

  if (target_node->type_string() == "PlaceholderWithDefault") {
    return Status::OK();
  }

  // TODO(skyewm): more of the filtering applied in input nodes below should be
  // applied to target_node here

  struct NodeAndRecursed {
    Node* new_node = nullptr;
    bool recursed = false;
  };

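  // Maps each visited node of the original graph to its copy in 'out_graph',
  // and records whether that node's own inputs have been enqueued yet.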
  std::map<Node*, NodeAndRecursed> old_to_new_and_recursed;
  Node* target_node_copy = out_graph->CopyNode(target_node);
  old_to_new_and_recursed[target_node].new_node = target_node_copy;
  old_to_new_and_recursed[target_node].recursed = true;

  // Add the target node's inputs to seed the recursion.
  std::deque<const Edge*> edges_to_visit;
  for (const Edge* e : target_node->in_edges()) {
    // TODO(vrv): What do we do about control edges?  Based on our
    // definition of a constant graph, we should be free to ignore
    // control edges since the order in which a constant graph is
    // executed should be the same regardless of when nodes run: we
    // should only need to recurse down data edges.
    if (e->IsControlEdge()) continue;
    edges_to_visit.push_back(e);
  }

  *is_constant_graph = true;

  // Iterate over the set of edges to visit (backwards).
  while (!edges_to_visit.empty()) {
    const Edge* current_edge = edges_to_visit.front();
    edges_to_visit.pop_front();
    Node* current_node = current_edge->src();

    // If the node is stateful, assume the graph is not constant.
    if (current_node->op_def().is_stateful()) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // During construction or import from GraphConstructor, back edges may not
    // be filled in.  Don't constant fold through merges at all for now.
    if (IsMerge(current_node)) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // Don't constant fold enter/exit currently either, as it's easy to end
    // up with a partial frame.
    if (IsEnter(current_node) || IsExit(current_node)) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // Placeholders should never be constant folded because their outputs are
    // fed by the user. Note that "Placeholder" nodes have no inputs so are
    // handled below.
    if (current_node->type_string() == "PlaceholderWithDefault") {
      *is_constant_graph = false;
      return Status::OK();
    }

    // If there is nothing more to recurse down, see if
    // the generator node is a constant.
    if (current_node->num_inputs() == 0) {
      if (!current_node->IsConstant()) {
        // Generator node is not a constant, so subgraph is not
        // constant.
        *is_constant_graph = false;
        return Status::OK();
      }
    }

    // Either the node is a constant, or the node is a potential
    // intermediate node on the path from a constant.
    //
    // Add a copy of its node and a new edge to the new subgraph.

    // Get or create the version of 'current_node' in the new graph.
    Node* current_node_copy;
    // This gets or creates the NodeAndRecursed entry for current_node.
    NodeAndRecursed* node_and_recursed = &old_to_new_and_recursed[current_node];
    if (node_and_recursed->new_node == nullptr) {
      // First time processing this node.
      current_node_copy = out_graph->CopyNode(current_node);
      // Track the mapping from the original node to the new one.
      node_and_recursed->new_node = current_node_copy;
    } else {
      current_node_copy = node_and_recursed->new_node;
    }

    // Add the edge to the destination node.
    {
      auto it = old_to_new_and_recursed.find(current_edge->dst());
      if (it == old_to_new_and_recursed.end()) {
        return errors::Internal(
            "Could not find mapping from old to new copy of destination node: ",
            current_edge->dst()->name());
      }
      Node* dst_copy = it->second.new_node;

      out_graph->AddEdge(current_node_copy, current_edge->src_output(),
                         dst_copy, current_edge->dst_input());
    }

    const string& output_tensor_name =
        strings::StrCat(current_node->name(), ":", current_edge->src_output());

    // Some tensor values can be inferred. For example, a shape op
    // with input shapes fully defined can have its output tensor inferred.
    Tensor tensor_inferred;
    bool successfully_inferred_tensor = false;
    TF_RETURN_IF_ERROR(TryToInferTensorOutputFromInputShapes(
        current_edge, &tensor_inferred, &successfully_inferred_tensor));
    if (successfully_inferred_tensor) {
      const_inputs->emplace_back(output_tensor_name, tensor_inferred);
      const_inputs_added.insert(output_tensor_name);
      continue;
    }

    // If we already have a materialized copy of the input tensor, add it to
    // the list of inputs to feed and do not recurse further.
    auto it = const_tensor_map_.find(output_tensor_name);
    if (it != const_tensor_map_.end() &&
        const_inputs_added.count(output_tensor_name) == 0) {
      const_inputs->emplace_back(output_tensor_name, it->second);
      const_inputs_added.insert(output_tensor_name);
      continue;
    }

    // If this node's inputs have not been processed already, do so now.
    if (!node_and_recursed->recursed) {
      node_and_recursed->recursed = true;
      for (const Edge* e : current_node->in_edges()) {
        if (e->IsControlEdge()) continue;
        edges_to_visit.push_back(e);
      }
    }
  }

  return Status::OK();
}

Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
                                          const Node* node, int dst_idx,
                                          ShapeHandle* result) {
  const Edge* input_edge;
  TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));

  InferenceContext* src_context = GetContext(input_edge->src());
  if (src_context == nullptr) return errors::Internal("Missing src context");
  ShapeHandle src_shape = src_context->output(input_edge->src_output());
  TF_RETURN_IF_ERROR(src_context->WithRank(src_shape, 1, &src_shape));

  const string& src_op = input_edge->src()->type_string();
  if (src_context->Value(src_context->Dim(src_shape, 0)) == 0) {
    // The source tensor is a vector of length 0, so the shape it
    // represents is the scalar shape.
    *result = target_context->Scalar();
  } else if (src_op == "Shape") {
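    // The output of Shape is the shape of its input, which the source
    // context already tracks symbolically.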
    *result = src_context->input(0);
  } else if (src_op == "ShapeN") {
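    // ShapeN emits one shape per input; use the one matching this output.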
    *result = src_context->input(input_edge->src_output());
  } else if (src_op == "Pack") {
    std::vector<DimensionHandle> dims;
    // Pack is concatenating its input scalars to form the shape tensor vector.
    for (int i = 0; i < src_context->num_inputs(); ++i) {
      Tensor scalar;
      bool evaluated = false;
      TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(input_edge->src(), i,
                                                       &evaluated, &scalar));
      if (evaluated) {
        int64 size;
        if (scalar.dtype() == DT_INT32) {
          size = scalar.scalar<int32>()();
        } else if (scalar.dtype() == DT_INT64) {
          size = scalar.scalar<int64>()();
        } else {
          return errors::InvalidArgument("Pack input must be int32 or int64");
        }
        dims.push_back(size < 0 ? target_context->UnknownDim()
                                : target_context->MakeDim(size));
      } else {
        dims.push_back(target_context->UnknownDim());
      }
    }
    *result = target_context->MakeShape(dims);
  } else if (src_op == "Concat" || src_op == "ConcatV2") {
    *result = target_context->Scalar();
    // For Concat, input 0 is concat dim; for V2 it is the last input.
    const int concat_dim =
        src_op == "Concat" ? 0 : src_context->num_inputs() - 1;
    // Concat is concatenating its input shape vectors.
    for (int i = 0; i < src_context->num_inputs(); ++i) {
      // Concat dim is ignored (and will always be a scalar).
      if (i == concat_dim) continue;
      ShapeHandle sub_result;
      TF_RETURN_IF_ERROR(ConstantPartialShape(target_context, input_edge->src(),
                                              i, &sub_result));
      if (!target_context->RankKnown(sub_result)) {
        // Failed to evaluate. Treat the output as completely unknown.
        // TODO(cwhipkey): we could rely on all inputs being the same rank, so
        // figure that rank out and append the right number of unknown dims.
        *result = target_context->UnknownShape();
        return Status::OK();
      }
      TF_RETURN_IF_ERROR(
          target_context->Concatenate(*result, sub_result, result));
    }
  } else {
    Tensor t;
    bool evaluated = false;
    TF_RETURN_IF_ERROR(
        EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t));
    TF_RETURN_IF_ERROR(target_context->MakeShapeFromTensor(
        evaluated ? &t : nullptr, src_shape, result));
  }
  return Status::OK();
}

Status ShapeRefiner::RunShapeFn(const Node* node,
                                const OpRegistrationData* op_reg_data,
                                ExtendedInferenceContext* ec) {
  // This will be filled in with real data in a second pass.
  std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
  std::vector<Tensor> real_tensors(node->num_inputs());
  std::vector<bool> attempted_materialization(node->num_inputs());
  std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs());
  std::vector<ShapeHandle> input_tensors_as_shapes;

  auto* c = ec->get_context();

  c->set_input_tensors(input_tensors);
  c->set_input_tensors_as_shapes(input_tensors_as_shapes);

  // Run the shape inference function, and return if there was an error.
  // Capture as lambda, because we might need to re-run inference later on.
  auto run_inference_lambda = [&]() {
    if (function_library_ && op_reg_data->is_function_op) {
      // Special inference logic for user-defined functions.

      auto* func_def = function_library_->Find(op_reg_data->op_def.name());
      if (func_def) {
        return InferShapesForFunction(func_def, keep_nested_shape_inferences_,
                                      ec);
      }
    }

    if (op_reg_data->shape_inference_fn) {
      TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn));
    } else {
      TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape));
    }
    return Status::OK();
  };
  TF_RETURN_IF_ERROR(run_inference_lambda());

  // We must run the shape function repeatedly, in case users write
  // shape functions where they only conditionally call input_tensor()
  // based on the values of another input tensor.
  bool rerun_shape_fn;
  do {
    // If the result of running shape inference would have benefitted
    // from knowing the values of input tensors, try to materialize
    // the results of those tensors, and then run the shape inference
    // function again using those known tensors.
    rerun_shape_fn = false;

    // NOTE: It is possible to batch the extraction and
    // materialization of inputs, instead of materializing one input
    // at a time like we do below.  If input-at-a-time computation
    // becomes a bottleneck, we could separate ExtractConstantSubgraph
    // into two functions: one that returns true if an input is
    // derivable from constants, and another function that extracts
    // the subgraph for multiple target nodes and executes the whole
    // subgraph once.

    for (int i = 0; i < c->num_inputs(); ++i) {
      if (!c->requested_input_tensor(i)) {
        continue;
      }
      // Check if we have not already filled in the requested input,
      // and if not, try to materialize the tensors.
      if (!attempted_materialization[i]) {
        attempted_materialization[i] = true;

        Tensor result;
        bool evaluated = false;
        TF_RETURN_IF_ERROR(
            EvaluateConstantTensorForEdge(node, i, &evaluated, &result));
        if (evaluated) {
          real_tensors[i] = result;
          input_tensors[i] = &real_tensors[i];
          // We have more concrete information about a shape,
          // so re-run shape inference.
          rerun_shape_fn = true;
        }
      }
      if (c->requested_input_tensor_as_partial_shape(i) &&
          !attempted_tensor_as_shape_conversion[i]) {
        attempted_tensor_as_shape_conversion[i] = true;
        if (i >= input_tensors_as_shapes.size()) {
          input_tensors_as_shapes.resize(i + 1);
        }
        ShapeHandle s;
        TF_RETURN_IF_ERROR(ConstantPartialShape(c, node, i, &s));
        input_tensors_as_shapes[i] = s;
        rerun_shape_fn = true;
      }
    }

    if (rerun_shape_fn) {
      // We have more information about the shapes on this pass,
      // so re-run shape inference.
      c->set_input_tensors(input_tensors);
      c->set_input_tensors_as_shapes(input_tensors_as_shapes);
      TF_RETURN_IF_ERROR(run_inference_lambda());
    }
  } while (rerun_shape_fn);

  return Status::OK();
}

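// Returns true if the two shapes are the same handle, or if they have the
// same known rank and every corresponding dimension pair either shares a
// handle or has the same known value.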
bool ShapeRefiner::SameDefinedShape(InferenceContext* c, ShapeHandle s0,
                                    ShapeHandle s1) {
  if (s0.SameHandle(s1)) {
    return true;
  }
  if (c->Rank(s0) != c->Rank(s1)) {
    return false;
  }
  if (!c->RankKnown(s0) && !c->RankKnown(s1)) {
    return false;
  }
  for (int i = 0; i < c->Rank(s0); ++i) {
    if (!c->Dim(s0, i).SameHandle(c->Dim(s1, i))) {
      int64 val0 = c->Value(c->Dim(s0, i));
      int64 val1 = c->Value(c->Dim(s1, i));
      if (val0 < 0 || val1 < 0 || val0 != val1) {
        return false;
      }
    }
  }

  return true;
}

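// Returns true if the updated shapes and types differ from the existing ones:
// a different count, a changed dtype, or any shape that is no longer the same
// defined shape.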
bool ShapeRefiner::IsUpdatedShapesOrTypes(
    InferenceContext* c, const std::vector<ShapeAndType>& existing,
    const std::vector<ShapeAndType>& updated) {
  if (existing.size() != updated.size()) {
    return true;
  }
  for (int i = 0; i < existing.size(); i++) {
    if (!SameDefinedShape(c, existing[i].shape, updated[i].shape) ||
        existing[i].dtype != updated[i].dtype) {
      return true;
    }
  }
  return false;
}

}  // namespace tensorflow