/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/eval_const_tensor.h"

#include <deque>

#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

using shape_inference::InferenceContext;

namespace {

// Tries to infer tensor output based on the input shapes of the node. In some
// cases, the shapes of the inputs are sufficient for inferring the contents of
// the output tensor. For example, a Shape op with fully defined input shapes
// can have its output tensor inferred.
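// Example (illustrative): if the input to a Shape node is statically known to
// have shape [2, 3, 5], its output can be materialized as the constant vector
// [2, 3, 5] without running the graph.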
Status TryToInferTensorOutputFromInputShapes(const Edge& edge,
                                             const ShapeRefiner& refiner,
                                             Tensor* output, bool* success) {
  *success = false;
  const Node* node = edge.src();
  InferenceContext* c = refiner.GetContext(node);
  if (c == nullptr) {
    return errors::FailedPrecondition("Node does not have context.");
  }

  if (node->type_string() == "Shape") {
    // If input shapes to the shape op are fully defined,
    // we can infer the shape op's output tensor.
    bool fully_defined_inputs = c->FullyDefined(c->input(0));
    if (fully_defined_inputs) {
      int input_rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({input_rank}));
      if (node->output_type(0) == DT_INT32) {
        auto flat = t.flat<int>();
        for (int i = 0; i < input_rank; i++) {
          int64 dimension = c->Value(c->Dim(c->input(0), i));
          if (!FastBoundsCheck(dimension, std::numeric_limits<int32>::max())) {
            return errors::InvalidArgument(
                "Shape has output type int32, but dimension exceeds maximum "
                "int32 value");
          }
          flat(i) = static_cast<int32>(dimension);
        }
      } else if (node->output_type(0) == DT_INT64) {
        auto flat = t.flat<int64>();
        for (int i = 0; i < input_rank; i++) {
          flat(i) = c->Value(c->Dim(c->input(0), i));
        }
      } else {
        return errors::FailedPrecondition(
            "Shape has output type that is not int32 or int64");
      }
      *output = t;
      *success = true;
    }
  } else if (node->type_string() == "Rank") {
    bool rank_known = c->RankKnown(c->input(0));
    if (rank_known) {
      int32 input_rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({}));
      t.flat<int32>()(0) = input_rank;
      *output = t;
      *success = true;
    }
  } else if (node->type_string() == "Size") {
    bool fully_defined_inputs = c->FullyDefined(c->input(0));
    if (fully_defined_inputs) {
      int32 rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({}));
      int64 size = 1;
      for (int i = 0; i < rank; i++) {
        size *= c->Value(c->Dim(c->input(0), i));
      }
      if (node->output_type(0) == DT_INT32) {
        if (!FastBoundsCheck(size, std::numeric_limits<int32>::max())) {
          return errors::InvalidArgument(
              "Size has output type int32, but size exceeds maximum int32 "
              "value");
        }
        t.flat<int32>()(0) = static_cast<int32>(size);
      } else if (node->output_type(0) == DT_INT64) {
        t.flat<int64>()(0) = size;
      } else {
        return errors::FailedPrecondition(
            "Size has output type that is not int32 or int64");
      }
      *output = t;
      *success = true;
    }
  }
  return Status::OK();
}

// Returns true if 'node' has a registered CPU kernel.
bool HasCpuKernel(const Node& node) {
  return FindKernelDef(DeviceType(DEVICE_CPU), node.def(), /*def=*/nullptr,
                       /*kernel_class_name=*/nullptr)
      .ok();
}

// Extracts the subgraph ending at 'target_node' that is statically computable
// and inserts it into 'out_graph'. If the subgraph is statically computable,
// 'is_constant_graph' will be set to true.
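// Inputs whose values are already known, either inferred from input shapes or
// found in 'cached_values', are recorded in 'const_inputs' so they can be fed
// to the extracted subgraph when it is run.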
Status ExtractConstantSubgraph(
    const Node& target_node, const ShapeRefiner& refiner,
    const std::unordered_map<string, Tensor>* cached_values, Graph* out_graph,
    bool* is_constant_graph,
    std::vector<std::pair<string, Tensor>>* const_inputs) {
  *is_constant_graph = false;
  std::unordered_set<string> const_inputs_added;

  if (target_node.op_def().is_stateful()) {
    return Status::OK();
  }

  if (IsMerge(&target_node)) {
    return Status::OK();
  }

  if (target_node.type_string() == "PlaceholderWithDefault") {
    return Status::OK();
  }

  // Since constant-folding runs on the CPU, do not attempt to constant-fold
  // operators that have no CPU kernel.
  if (!HasCpuKernel(target_node)) {
    return Status::OK();
  }

  // TODO(skyewm): should more of the filtering applied to input nodes below be
  // applied to target_node here?

  // Identify the possibly constant subgraph by recursively iterating backwards
  // through the inputs to 'target_node' until we either 1) find an input that
  // is already in 'const_inputs', 2) discover that the graph is not constant,
  // or 3) hit a root node.

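  // Per-node bookkeeping: the copy of the node created in 'out_graph' (if any)
  // and whether the node's inputs have already been added to the worklist.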
  struct NodeAndRecursed {
    Node* new_node = nullptr;
    bool recursed = false;
  };

  std::map<const Node*, NodeAndRecursed> old_to_new_and_recursed;
  Node* target_node_copy = out_graph->CopyNode(&target_node);
  old_to_new_and_recursed[&target_node].new_node = target_node_copy;
  old_to_new_and_recursed[&target_node].recursed = true;

  // Add the target node's inputs to seed the recursion.
  std::deque<const Edge*> edges_to_visit;
  for (const Edge* e : target_node.in_edges()) {
    // TODO(skyewm): control edges will be meaningful if/when we handle control
    // flow (e.g. constants in cond branches are triggered via control edges).
    if (e->IsControlEdge()) continue;
    edges_to_visit.push_back(e);
  }

  *is_constant_graph = true;

  // Iterate over the set of edges to visit (backwards).
  while (!edges_to_visit.empty()) {
    const Edge* current_edge = edges_to_visit.front();
    edges_to_visit.pop_front();
    Node* current_node = current_edge->src();

    // If the node is stateful, assume the graph is not constant.
    if (current_node->op_def().is_stateful()) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // During construction or import from GraphConstructor, back edges may not
    // be filled in. In addition, control flow constructs may depend on control
    // edges which aren't handled by this method. Don't constant fold through
    // merges at all for now.
    if (IsMerge(current_node)) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // Don't constant fold enter/exit currently either, as it's easy to end
    // up with a partial frame.
    if (IsEnter(current_node) || IsExit(current_node)) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // Placeholders should never be constant folded because their outputs are
    // fed by the user. Note that "Placeholder" nodes have no inputs, so they
    // are handled by the zero-input check below.
    if (current_node->type_string() == "PlaceholderWithDefault") {
      *is_constant_graph = false;
      return Status::OK();
    }

    if (!HasCpuKernel(*current_node)) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // If there is nothing more to recurse down, see if
    // the generator node is a constant.
    if (current_node->num_inputs() == 0) {
      if (!current_node->IsConstant()) {
        // Generator node is not a constant, so subgraph is not
        // constant.
        *is_constant_graph = false;
        return Status::OK();
      }
    }

    // Either the node is a constant, or the node is a potential
    // intermediate node on the path from a constant.
    //
    // Add a copy of its node and a new edge to the new subgraph.

    // Get or create the version of 'current_node' in the new graph.
    Node* current_node_copy;
    // This gets or creates the NodeAndRecursed entry for current_node.
    NodeAndRecursed* node_and_recursed = &old_to_new_and_recursed[current_node];
    if (node_and_recursed->new_node == nullptr) {
      // First time processing this node.
      current_node_copy = out_graph->CopyNode(current_node);
      // Track the mapping from the original node to the new one.
      node_and_recursed->new_node = current_node_copy;
    } else {
      current_node_copy = node_and_recursed->new_node;
    }

    // Add the edge to the destination node.
    {
      auto it = old_to_new_and_recursed.find(current_edge->dst());
      if (it == old_to_new_and_recursed.end()) {
        return errors::Internal(
            "Could not find mapping from old to new copy of destination node: ",
            current_edge->dst()->name());
      }
      Node* dst_copy = it->second.new_node;

      out_graph->AddEdge(current_node_copy, current_edge->src_output(),
                         dst_copy, current_edge->dst_input());
    }

    const string& output_tensor_name =
        strings::StrCat(current_node->name(), ":", current_edge->src_output());

    // Some tensor values can be inferred. For example, a shape op
    // with input shapes fully defined can have its output tensor inferred.
    Tensor tensor_inferred;
    bool successfully_inferred_tensor = false;
    TF_RETURN_IF_ERROR(TryToInferTensorOutputFromInputShapes(
        *current_edge, refiner, &tensor_inferred,
        &successfully_inferred_tensor));
    if (successfully_inferred_tensor) {
      const_inputs->emplace_back(output_tensor_name, tensor_inferred);
      const_inputs_added.insert(output_tensor_name);
      continue;
    }

    // If we have a copy of the input tensor materialized already,
    // then add to the list of inputs to feed and do not recurse further.
    if (cached_values != nullptr) {
      auto it = cached_values->find(output_tensor_name);
      if (it != cached_values->end() &&
          const_inputs_added.count(output_tensor_name) == 0) {
        const_inputs->emplace_back(output_tensor_name, it->second);
        const_inputs_added.insert(output_tensor_name);
        continue;
      }
    }

    // If this node's inputs have not been processed already, do so now.
    if (!node_and_recursed->recursed) {
      node_and_recursed->recursed = true;
      for (const Edge* e : current_node->in_edges()) {
        if (e->IsControlEdge()) continue;
        edges_to_visit.push_back(e);
      }
    }
  }

  return Status::OK();
}

}  // namespace

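// Attempts to evaluate 'tensor' as a constant. If the source node is itself a
// Const, its value is read directly from the NodeDef. Otherwise, unless
// constant propagation is disabled, a statically computable subgraph feeding
// 'tensor' is extracted and run with 'graph_runner'; on success '*evaluated'
// is set to true and '*result' holds the value. Small results may be memoized
// in 'cached_values' (bounded by 'max_cached_value_size').
//
// Typical use (illustrative): a shape function asks for the value of a tensor
// that feeds a shape input; if '*evaluated' comes back true, the shape can be
// refined from '*result'.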
Status EvaluateConstantTensor(OutputTensor tensor, const ShapeRefiner& refiner,
                              const OpRegistryInterface& ops,
                              int32 graph_def_version, bool* evaluated,
                              Tensor* result, GraphRunner* graph_runner,
                              std::unordered_map<string, Tensor>* cached_values,
                              int64 max_cached_value_size,
                              bool disable_constant_propagation) {
  *evaluated = false;
  const Node* src = tensor.node;

  // Simple case: the source node is a constant.
  if (src->IsConstant()) {
    if (result->FromProto(src->def().attr().at("value").tensor())) {
      *evaluated = true;
      return Status::OK();
    }
  }

  if (disable_constant_propagation) {
    return Status::OK();
  }

  bool is_constant_graph = false;
  Graph subgraph(&ops);
  auto versions = subgraph.versions();
  versions.set_producer(graph_def_version);
  subgraph.set_versions(versions);

  std::vector<std::pair<string, Tensor>> const_inputs;
  TF_RETURN_IF_ERROR(ExtractConstantSubgraph(*src, refiner, cached_values,
                                             &subgraph, &is_constant_graph,
                                             &const_inputs));
  if (!is_constant_graph) {
    return Status::OK();
  }
  const string output_tensor_name =
      strings::StrCat(src->name(), ":", tensor.index);
  std::vector<Tensor> outputs;

  std::unique_ptr<GraphRunner> graph_runner_storage;
  if (graph_runner == nullptr) {
    // TODO(skyewm): Convert to std::make_unique when available.
    graph_runner_storage.reset(new GraphRunner(Env::Default()));
    graph_runner = graph_runner_storage.get();
  }

  // NOTE: we should pass in a function library runtime if we want
  // to support constant-expression evaluation on functions.
  Status s = graph_runner->Run(&subgraph, nullptr /* function_library */,
                               const_inputs, {output_tensor_name}, &outputs);

  // If some kernels in the constant graph are not registered in the process,
  // GraphRunner::Run may fail, in which case we cannot propagate constants,
  // so this is best-effort.
  if (s.ok()) {
    *result = outputs[0];
    *evaluated = true;

    // We memoize (small) constants evaluated so far, so
    // ExtractConstantSubgraph can avoid extracting the full
    // subgraph.  As we build up large graphs, this avoids
    // repeated computation of the early parts of a constant
    // graph.
    if (cached_values != nullptr &&
        outputs[0].TotalBytes() <= max_cached_value_size) {
      (*cached_values)[output_tensor_name] = outputs[0];
    }
  }
  return Status::OK();
}

}  // namespace tensorflow