Home | History | Annotate | Download | only in common_runtime
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_
     16 #define TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_
     17 
     18 #include <vector>
     19 
     20 #include "tensorflow/core/common_runtime/graph_runner.h"
     21 #include "tensorflow/core/framework/function.pb.h"
     22 #include "tensorflow/core/framework/shape_inference.h"
     23 #include "tensorflow/core/graph/graph.h"
     24 #include "tensorflow/core/lib/core/status.h"
     25 #include "tensorflow/core/platform/macros.h"
     26 
     27 namespace tensorflow {
     28 namespace grappler {
     29 class GraphProperties;
     30 }
     31 
     32 // This class stores extra inference information in addition to
     33 // InferenceContext, such as inference tree for user-defined functions and node
     34 // input and output types.
     35 class ExtendedInferenceContext {
     36  public:
     37   ExtendedInferenceContext(
     38       std::unique_ptr<shape_inference::InferenceContext> ic, const Node* node)
     39       : inference_context_(std::move(ic)) {
     40     input_types_.reserve(node->num_inputs());
     41     for (int i = 0; i < node->num_inputs(); i++) {
     42       input_types_.push_back(node->input_type(i));
     43     }
     44     output_types_.reserve(node->num_outputs());
     45     for (int i = 0; i < node->num_outputs(); i++) {
     46       output_types_.push_back(node->output_type(i));
     47     }
     48   }
     49 
     50   const std::unordered_map<string, std::unique_ptr<ExtendedInferenceContext>>&
     51   nested_inferences() const {
     52     return nested_inferences_;
     53   }
     54   DataType input_type(int64 idx) const { return input_types_[idx]; }
     55   DataType output_type(int64 idx) const { return output_types_[idx]; }
     56 
     57   shape_inference::InferenceContext* get_context() {
     58     return inference_context_.get();
     59   }
     60 
     61   // Sets nested inference info.
     62   // For composite ops (user-defined functions) only.
     63   // Inference for trivial ops must not call this setter.
     64   void set_nested_inferences(
     65       std::unordered_map<string, std::unique_ptr<ExtendedInferenceContext>>
     66           inferences) {
     67     nested_inferences_ = std::move(inferences);
     68   }
     69 
     70  private:
     71   std::unique_ptr<shape_inference::InferenceContext> inference_context_;
     72   std::vector<DataType> input_types_;
     73   std::vector<DataType> output_types_;
     74 
     75   // Nested inferences for composite ops (user-defined functions).
     76   // Mapping key is nested node name.
     77   // For trivial ops this map must be empty.
     78   std::unordered_map<string, std::unique_ptr<ExtendedInferenceContext>>
     79       nested_inferences_;
     80 
     81   TF_DISALLOW_COPY_AND_ASSIGN(ExtendedInferenceContext);
     82 };
     83 
     84 // ShapeRefiner performs shape inference for TensorFlow Graphs.  It is
     85 // responsible for instantiating InferenceContext objects for each
     86 // Node in the Graph, and providing/storing the 'input_tensor' Tensors
     87 // used by Shape Inference functions, when available at graph
     88 // construction time.
     89 class ShapeRefiner {
     90  public:
     91   ShapeRefiner(int graph_def_version, const OpRegistryInterface* ops);
     92 
     93   // Same as ShapeRefiner(versions.producer(), ops)
     94   ShapeRefiner(const VersionDef& versions, const OpRegistryInterface* ops);
     95 
     96   ~ShapeRefiner();
     97 
     98   // Performs validation of 'node' and runs 'node's shape function,
     99   // storing its shape outputs.
    100   //
    101   // All inputs of 'node' must be added to ShapeRefiner prior to
    102   // adding 'node'.
    103   //
    104   // Returns an error if:
    105   //  - the shape function for 'node' was not registered.
    106   //  - 'node' was added before its inputs.
    107   //  - The shape inference function returns an error.
    108   Status AddNode(const Node* node);
    109 
    110   // Sets 'node's 'output_port' output to have shape 'shape'.
    111   //
    112   // Returns an error if 'node' was not previously added to this
    113   // object, if 'output_port' is invalid, or if 'shape' is
    114   // not compatible with the existing shape of the output.
    115   Status SetShape(const Node* node, int output_port,
    116                   shape_inference::ShapeHandle shape);
    117 
    118   // Update the input shapes of node in case the shapes of the fan-ins of 'node'
    119   // have themselves been modified (For example, in case of incremental shape
    120   // refinement). If 'relax' is true, a new shape with the broadest set of
    121   // information will be set as the new input (see InferenceContext::RelaxInput
    122   // for full details and examples). Sets refined to true if any shapes have
    123   // changed (in their string representations). Note that shapes may have been
    124   // updated to newer versions (but with identical string representations) even
    125   // if <*refined> is set to false.
    126   Status UpdateNode(const Node* node, bool relax, bool* refined);
    127 
    128   // Returns the InferenceContext for 'node', if present.
    129   shape_inference::InferenceContext* GetContext(const Node* node) const {
    130     auto it = node_to_context_.find(node);
    131     if (it == node_to_context_.end()) {
    132       return nullptr;
    133     }
    134     return it->second->get_context();
    135   }
    136 
    137   // Returns the ExtendedInferenceContext for 'node', if present.
    138   ExtendedInferenceContext* GetExtendedContext(const Node* node) const {
    139     auto it = node_to_context_.find(node);
    140     if (it == node_to_context_.end()) {
    141       return nullptr;
    142     }
    143     return it->second.get();
    144   }
    145 
    146   // Getters and setters for graph_def_version_.
    147   int32 graph_def_version() const { return graph_def_version_; }
    148   void set_graph_def_version(int32 version) { graph_def_version_ = version; }
    149 
    150   void set_require_shape_inference_fns(bool require_shape_inference_fns) {
    151     require_shape_inference_fns_ = require_shape_inference_fns;
    152   }
    153   void set_disable_constant_propagation(bool disable) {
    154     disable_constant_propagation_ = disable;
    155   }
    156 
    157   // Set function library to enable function shape inference.
    158   // Without function library, function inference always yields unknown shapes.
    159   // With this enabled, shape inference can take more time since it descends
    160   // into all function calls. It doesn't do inference once for each function
    161   // definition, but once for each function call.
    162   // The function library must outlive the shape refiner.
    163   void set_function_library_for_shape_inference(
    164       const tensorflow::FunctionLibraryDefinition* lib) {
    165     function_library_ = lib;
    166   }
    167 
    168   bool function_shape_inference_supported() const {
    169     return function_library_ != nullptr;
    170   }
    171 
    172   // Call this to keep nested shapes information for user-defined functions:
    173   // nested inferences will be available on the ExtendedInferenceContext for
    174   // each function node, forming a tree of shape inferences corresponding to the
    175   // tree of nested function calls. By default this setting is disabled, and
    176   // only the shapes for the top-level function node will be reported on the
    177   // InferenceContext for each function node, to reduce memory usage.
    178   //
    179   // This flag has no effect when the function inference is not enabled via
    180   // set_function_library_for_shape_inference.
    181   void set_keep_nested_shape_inferences() {
    182     keep_nested_shape_inferences_ = true;
    183   }
    184 
    185  private:
    186   friend class ShapeRefinerTest;
    187   friend class ::tensorflow::grappler::GraphProperties;
    188 
    189   // Returns true if the ranks and all dimensions of <s0> and <s1> are either
    190   // equal in value or both unknown.
    191   static bool SameDefinedShape(shape_inference::InferenceContext* c,
    192                                shape_inference::ShapeHandle s0,
    193                                shape_inference::ShapeHandle s1);
    194 
    195   // Returns true if the shapes and types stored in <*existing> are identical in
    196   // value to the shapes and types in <*updated>.
    197   static bool IsUpdatedShapesOrTypes(
    198       shape_inference::InferenceContext* c,
    199       const std::vector<shape_inference::ShapeAndType>& existing,
    200       const std::vector<shape_inference::ShapeAndType>& updated);
    201 
    202   // Performs shape inference for the given function_def within the
    203   // given outer_context. Internally it instantiates the function as a graph
    204   // and runs shape inference recursively on it with the input shapes provided
    205   // by the outer_context.
    206   //
    207   // Returns an error if:
    208   // - number of inputs/outputs on outer_context doesn't match the function_def
    209   //
    210   // On success:
    211   // - outer_context will contain output shapes inferred from input shapes
    212   // - outer_context will contain nested inferences collection, iff
    213   //   keep_nested_shapes is true
    214   Status InferShapesForFunction(const tensorflow::FunctionDef* function_def,
    215                                 bool keep_nested_shapes,
    216                                 ExtendedInferenceContext* outer_context);
    217 
    218   // Tries to infer tensor output based on the input shapes of the node. In some
    219   // cases, the shapes of the inputs are sufficient for inferring the contents
    220   // of the output tensor. For example, a Shape op with fully defined input
    221   // shapes can have its output tensor inferred.
    222   Status TryToInferTensorOutputFromInputShapes(const Edge* edge, Tensor* output,
    223                                                bool* success);
    224 
    225   // Extracts the subgraph ending at 'node' that is statically
    226   // computable and inserts into 'out_graph'. If statically computable,
    227   // 'is_constant_graph' will be true.
    228   Status ExtractConstantSubgraph(
    229       Node* node, Graph* out_graph, bool* is_constant_graph,
    230       std::vector<std::pair<string, Tensor>>* const_inputs) TF_MUST_USE_RESULT;
    231 
    232   Status EvaluateConstantTensorForEdge(const Node* node, int dst_idx,
    233                                        bool* evaluated, Tensor* result);
    234 
    235   // This function tries to materialize as much information about the 'node''s
    236   // dst_idx input as a statically computable shape, and the result may be
    237   // partially known, depending on what is statically inferable.
    238   //
    239   // This is called when node.input[dst_idx] is a tensor that is used to define
    240   // the shape of some other tensor (e.g., the second argument to Reshape is a
    241   // <shape> tensor, where each element of the shape tensor is a dimension of
    242   // the target tensor).  It returns in <result> a shape for that input.
    243   //
    244   // Unlike simply resolving node.input[dst_idx] to a constant and then
    245   // converting that to a shape, this function can return a partial shape. This
    246   // is useful for cases where the shape tensor is only partially defined, such
    247   // as with calls for: reshape(x, shape(y)) where shape(y) is partially
    248   // defined.
    249   //
    250   // The implementation has op implementations for ops commonly called on shape
    251   // tensors, and the implementations are specialized to shape tensors (namely,
    252   // the output is a vector).
    253   //
    254   // <target_context> is used when creating new DimensionHandle and ShapeHandle
    255   // objects.
    256   Status ConstantPartialShape(shape_inference::InferenceContext* target_context,
    257                               const Node* node, int dst_idx,
    258                               shape_inference::ShapeHandle* result);
    259 
    260   Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data,
    261                     ExtendedInferenceContext* ec);
    262 
    263   int32 graph_def_version_;
    264   const OpRegistryInterface* const ops_registry_;
    265 
    266   // The lifetime of the tensors are bound to the runner, so it should be the
    267   // deleted after the tensors.
    268   GraphRunner graph_runner_;
    269 
    270   // Stores a map from a node to its ExtendedInferenceContext.
    271   std::unordered_map<const Node*, std::unique_ptr<ExtendedInferenceContext>>
    272       node_to_context_;
    273 
    274   // Holds a cache from 'tensor name' to the tensor that is
    275   // evaluatable as a constant expression.  This reduces repeated
    276   // execution of the entire constant subgraph as a graph is being
    277   // built up.  This could be changed to some kind of size-based LRU
    278   // cache to avoid consuming too much memory, if that eventually
    279   // becomes a concern.
    280   //
    281   // Only tensors less than 1KiB are currently stored in the cache.
    282   static constexpr int64 kMaxTensorSize = 1024;
    283   std::unordered_map<string, Tensor> const_tensor_map_;
    284 
    285   bool require_shape_inference_fns_ = true;
    286   bool disable_constant_propagation_ = false;
    287 
    288   // Function library is optional, but has to be set to enable function
    289   // shape inference.
    290   const tensorflow::FunctionLibraryDefinition* function_library_ = nullptr;
    291 
    292   // Determines whether to keep the nested shape inference info for user-
    293   // defined functions. By default that info is discarded to save memory.
    294   bool keep_nested_shape_inferences_ = false;
    295 
    296   // Cache the graph corresponding to each functin definition for which shapes
    297   // are refined.
    298   std::unordered_map<const FunctionDef*, std::unique_ptr<const Graph>>
    299       functions_;
    300 
    301   TF_DISALLOW_COPY_AND_ASSIGN(ShapeRefiner);
    302 };
    303 
    304 }  // namespace tensorflow
    305 
    306 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_
    307