Home | History | Annotate | Download | only in framework
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #ifndef TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
     16 #define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
     17 
     18 #include <vector>
     19 
     20 #include "tensorflow/core/framework/node_def_util.h"
     21 #include "tensorflow/core/framework/tensor.h"
     22 #include "tensorflow/core/lib/core/errors.h"
     23 #include "tensorflow/core/lib/core/status.h"
     24 #include "tensorflow/core/lib/gtl/inlined_vector.h"
     25 #include "tensorflow/core/platform/macros.h"
     26 
     27 namespace tensorflow {
     28 
     29 class ShapeRefiner;
     30 class ShapeRefinerTest;
     31 
     32 namespace grappler {
     33 class GraphProperties;
     34 class SymbolicShapeManager;
     35 }  // namespace grappler
     36 
     37 namespace shape_inference {
     38 
     39 struct DimensionOrConstant;
     40 class InferenceContext;
     41 
     42 // Dimension values are accessed through InferenceContext.
     43 class Dimension {
     44  private:
          // Construction and destruction are private: only the friends declared
          // below (InferenceContext and ShapeManager) can create or destroy a
          // Dimension.
     45   Dimension();
     46   Dimension(int64 value);
     47   ~Dimension() {}
     48 
          // The dimension's size; const, so it never changes after construction.
          // InferenceContext::Value() reads it and InferenceContext::ValueKnown()
          // compares it against InferenceContext::kUnknownDim.
     49   const int64 value_;
     50 
     51   friend class InferenceContext;
     52   friend class ShapeManager;
          // NOTE(review): presumably disables copying and assignment, as the macro
          // name says -- the macro is defined outside this file.
     53   TF_DISALLOW_COPY_AND_ASSIGN(Dimension);
     54 };
     55 
     56 class DimensionHandle {
     57  public:
          // Creates an unset handle; ptr_ keeps its nullptr default.
     58   DimensionHandle() {}
          // True when <d> points at the very same Dimension object as this handle.
     59   bool SameHandle(DimensionHandle d) const { return d.ptr_ == ptr_; }
          // The pointer held by this handle, converted to an integer.
     60   std::size_t Handle() const {
            const Dimension* const raw = ptr_;
            return reinterpret_cast<std::size_t>(raw);
          }
     61 
     62  private:
          // Implicit on purpose of the callers below: ShapeManager::MakeDim returns
          // a raw Dimension* where a DimensionHandle is expected.
     63   DimensionHandle(const Dimension* dim) : ptr_(dim) {}
     64 
     65   const Dimension* operator->() const { return ptr_; }
     66   bool IsSet() const { return nullptr != ptr_; }
     67 
     68   const Dimension* ptr_ = nullptr;
     69 
     70   friend struct DimensionOrConstant;
     71   friend class InferenceContext;
     72   friend class ShapeInferenceTest;
     73   friend class ShapeInferenceTestutil;
     74   friend class ::tensorflow::ShapeRefinerTest;
     75   friend class ShapeManager;
     76   friend class ::tensorflow::grappler::GraphProperties;
     77   friend class ::tensorflow::grappler::SymbolicShapeManager;
     78 
     79   // Intentionally copyable.
     80 };
     81 
     82 // Shape rank and dimensions are accessed through InferenceContext.
     83 class Shape {
     84  private:
          // Only the friends declared below can create or destroy a Shape.
     85   Shape();
     86   Shape(const std::vector<DimensionHandle>& dims);
     87   ~Shape() {}
     88 
          // Rank of the shape. InferenceContext compares it against
          // InferenceContext::kUnknownRank to detect a shape of unknown rank.
     89   const int32 rank_;
          // Dimension handles of the shape; indexed by
          // InferenceContext::DimKnownRank().
     90   const std::vector<DimensionHandle> dims_;
     91 
     92   friend class InferenceContext;
     93   friend class ShapeManager;
     94   friend class ::tensorflow::grappler::SymbolicShapeManager;
     95 
          // NOTE(review): presumably disables copying and assignment, as the macro
          // name says -- the macro is defined outside this file.
     96   TF_DISALLOW_COPY_AND_ASSIGN(Shape);
     97 };
     98 
     99 class ShapeHandle {
    100  public:
          // Creates an unset handle; ptr_ keeps its nullptr default.
    101   ShapeHandle() {}
          // True when <s> points at the very same Shape object as this handle.
    102   bool SameHandle(ShapeHandle s) const { return s.ptr_ == ptr_; }
          // The pointer held by this handle, converted to an integer.
    103   std::size_t Handle() const {
            const Shape* const raw = ptr_;
            return reinterpret_cast<std::size_t>(raw);
          }
    104 
    105  private:
          // Wraps a raw Shape pointer; reachable only from the friends below.
    106   ShapeHandle(const Shape* shape) : ptr_(shape) {}
    107   const Shape* operator->() const { return ptr_; }
    108   bool IsSet() const { return nullptr != ptr_; }
    109 
    110   const Shape* ptr_ = nullptr;
    111 
    112   friend class InferenceContext;
    113   friend class ShapeInferenceTest;
    114   friend class ShapeInferenceTestutil;
    115   friend class ::tensorflow::ShapeRefinerTest;
    116   friend class ShapeManager;
    117   friend class ::tensorflow::grappler::SymbolicShapeManager;
    118 
    119   // Intentionally copyable.
    120 };
    121 
    122 // Struct used to allow functions to take DimensionHandle or a dimension value.
    123 // Not meant to be constructed directly.
    124 struct DimensionOrConstant {
    125  public:
    126   // Intentionally not explicit, so that a DimensionHandle can be passed
          // wherever a DimensionOrConstant parameter is expected.
    127   DimensionOrConstant(DimensionHandle dim);
    128 
    129   // Also implicit. val must be non-negative or InferenceContext::kUnknownDim.
    130   DimensionOrConstant(int64 val);
    131 
    132   // dim takes precedence. If dim is set (its pointer is non-null), val is
          // ignored; see InferenceContext::Value().
    133   DimensionHandle dim;
    134   int64 val;
    135 
    136  private:
          // Private, so an instance cannot be created without either a handle or
          // a constant.
    137   DimensionOrConstant();
    138 };
    139 
    140 struct ShapeAndType {
          // A (shape, dtype) pair. InferenceContext uses vectors of these for the
          // shapes and dtypes associated with input/output handles.
          //
          // The default constructor leaves shape unset and dtype at DT_INVALID.
    141   ShapeAndType() {}
    142   ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {}
    143 
    144   ShapeHandle shape;
    145   DataType dtype = DT_INVALID;
    146 };
    147 
    148 // Shape inference functions registered on ops in REGISTER_OP implement
    149 // their shape functions in terms of this InferenceContext.  An InferenceContext
    150 // is created by the framework and passed to a shape inference function.  The
    151 // shape inference function calls functions on the context, and should call
    152 // set_output() to set the shape on all outputs.
    153 //
    154 // To infer shapes for user-defined functions see ShapeRefiner.
    155 //
    156 // All Shape* and Dimension* returned by functions of InferenceContext are owned
    157 // by the InferenceContext.
    158 class InferenceContext {
    159  public:
    160   static constexpr int64 kUnknownDim = -1;   // Value() of an unknown dim.
    161   static constexpr int32 kUnknownRank = -1;  // Rank() of an unknown shape.
    162 
    163   // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
    164   //
    165   // Elements of <input_tensors_as_shapes> are used for when a shape function
    166   // makes a call to MakeShapeFromShapeTensor; in particular, when the
    167   // input_tensors[i] is nullptr but the shape represented by it is partially
    168   // known from analysis of the graph.
    169   // <input_tensors_as_shapes> can have fewer elements than <input_shapes>.
    170   // Values of <input_tensors_as_shapes> do not need to outlive the context.
    171   //
    172   // REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext.
          //
          // This overload takes the shapes as ShapeHandle values. Unlike the other
          // two constructors, <input_handle_shapes_and_types> is passed by value;
          // its elements are std::unique_ptr, so the argument has to be a temporary
          // or be std::move()d in.
    173   InferenceContext(int graph_def_version, const NodeDef* node_def,
    174                    const OpDef& op_def,
    175                    const std::vector<ShapeHandle>& input_shapes,
    176                    const std::vector<const Tensor*>& input_tensors,
    177                    const std::vector<ShapeHandle>& input_tensors_as_shapes,
    178                    std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
    179                        input_handle_shapes_and_types);
    180 
    181   // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
    182   //
    183   // Elements of <input_tensors_as_shapes> are used for when a shape
    184   // function makes a call to MakeShapeFromShapeTensor; in particular, when
    185   // the input_tensors[i] is nullptr but the shape represented by it is
    186   // partially known from analysis of the graph. <input_tensors_as_shapes>
    187   // can have fewer elements than <input_shapes>. Values of
    188   // <input_tensors_as_shapes> do not need to outlive the context.
    189   //
    190   // REQUIRES: <node_def> is not NULL, and must outlive the
    191   // InferenceContext.
          //
          // This overload takes the shapes as TensorShapeProto values, and the
          // handle shapes/types as (TensorShapeProto, DataType) pairs by const
          // reference.
    192   InferenceContext(
    193       int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
    194       const std::vector<TensorShapeProto>& input_shapes,
    195       const std::vector<const Tensor*>& input_tensors,
    196       const std::vector<TensorShapeProto>& input_tensors_as_shapes,
    197       const std::vector<
    198           std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>&
    199           input_handle_shapes_and_types);
    200 
    201   // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
    202   //
    203   // Elements of <input_tensors_as_shapes> are used for when a shape
    204   // function makes a call to MakeShapeFromShapeTensor; in particular, when
    205   // the input_tensors[i] is nullptr but the shape represented by it is
    206   // partially known from analysis of the graph. <input_tensors_as_shapes>
    207   // can have fewer elements than <input_shapes>. Values of
    208   // <input_tensors_as_shapes> do not need to outlive the context.
    209   //
    210   // REQUIRES: <node_def> is not NULL, and must outlive the
    211   // InferenceContext.
          //
          // This overload takes the shapes as PartialTensorShape values, and the
          // handle shapes/types as (PartialTensorShape, DataType) pairs by const
          // reference.
    212   InferenceContext(
    213       int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
    214       const std::vector<PartialTensorShape>& input_shapes,
    215       const std::vector<const Tensor*>& input_tensors,
    216       const std::vector<PartialTensorShape>& input_tensors_as_shapes,
    217       const std::vector<std::unique_ptr<
    218           std::vector<std::pair<PartialTensorShape, DataType>>>>&
    219           input_handle_shapes_and_types);
    220 
    221   ~InferenceContext();
    222 
    223   // Runs the shape inference function 'fn' with 'this' as the
    224   // argument, and returns the status of the inference.
    225   //
    226   // On error, additional context is provided in the error message.
          // NOTE(review): presumably via AttachContext(), declared in the private
          // section -- the definition is not in this header, confirm.
    227   Status Run(
    228       const std::function<Status(shape_inference::InferenceContext* c)>& fn);
    229 
    230   // Merge the stored shape of the input in position idx with <shape> according
    231   // to the following rules:
    232   //
    233   // - If the ShapeHandles are the same or <shape> is unknown, there will be no
    234   //   change. Otherwise if the stored shape is unknown, the new shape will be
    235   //   <shape>.
    236   // - If both shapes are known, then they must have the same rank.
    237   // - For any one dimension, if the values for that dimension in both shapes
    238   //   are known, then the values must match.
    239   // - If one shape has equal or more information than the other shape in every
    240   //   dimension, the new shape will become the shape with more information.
    241   // - Example: merging [2,?] and [?,2] results in [2,2]
    242   // - Example: [2,2] cannot be merged with [1,2]
    243   //
    244   // This requires idx to be in the [0, num_inputs) range. If the merge is
    245   // successful, return true. Return false otherwise.
    246   bool MergeInput(int idx, ShapeHandle shape) {
            // Merge into a scratch handle so the stored input is only replaced
            // when Merge() reports success.
    247     ShapeHandle merged;
    248     const bool merge_ok = Merge(inputs_[idx], shape, &merged).ok();
    249     if (merge_ok) inputs_[idx] = merged;
    250     return merge_ok;
    251   }
    252 
    253   // Relax the stored shape of the input in position idx with <shape> according
    254   // to the following rules:
    255   //
    256   // - If the ShapeHandles are the same then the stored shape will be returned.
    257   // - If either of the ShapeHandles are unknown, then a new UnknownShape will
    258   //   be returned. A new shape must be returned because we cannot claim that
    259   //   the resulting shape is necessarily the same as either of the input
    260   //   shapes.
    261   // - If the shapes both have known ranks but their ranks are different, a new
    262   //   UnknownShape will be returned.
    263   // - For any one dimension, if the value for that dimension in either of the
    264   //   shapes is unknown, a new shape will be returned with a new UnknownDim in
    265   //   that dimension.
    266   // - For any one dimension, if the values for that dimension in both shapes
    267   //   are known but do not match, a new shape will be returned with a new
    268   //   UnknownDim in that dimension.
    269   // - If both shapes have the same known rank and match in every dimension,
    270   //   the stored shape will be returned.
    271   // - Example: relaxing [2,?] and [?,2] results in [?,?]
    272   // - Example: relaxing [2,2] and [3,2] results in [?,2]
    273   // - Example: relaxing [2,2] with [1,2,3] results in ?
    274   //
    275   // This requires idx to be in the [0, num_inputs) range. If the relax is
    276   // successful and the new shape differs from the old one, store the new
    277   // shape and return true. Return false otherwise.
    278   bool RelaxInput(int idx, ShapeHandle shape) {
    279     ShapeHandle relaxed;
    280     Relax(inputs_[idx], shape, &relaxed);
            // Handle identity is the "did anything change" test: per the contract
            // documented above, the stored handle itself comes back when the
            // relaxed shape equals the stored one.
    281     const bool changed = !inputs_[idx].SameHandle(relaxed);
    282     if (changed) {
    283       inputs_[idx] = relaxed;
    284     }
    285     return changed;
    286   }
    287 
    288   ShapeHandle input(int64 idx) const { return inputs_[idx]; }
          // ^ Returns the stored shape of input <idx>; <idx> is not range-checked
          //   here.
          //
          // NOTE(review): presumably looks up the inputs registered under
          // <input_name> and writes their shapes to <*output> -- the definition is
          // not in this header, confirm.
    289   Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const;
          // Number of inputs, i.e. the size of inputs_.
    290   int num_inputs() const { return inputs_.size(); }
    291 
    292   // Returns the input tensor at index <idx>, or nullptr if the input tensor is
    293   // not available at the time of shape inference.
    294   const Tensor* input_tensor(int idx) {
    295     const Tensor* const tensor = input_tensors_[idx];
    296     // Remember that this input was asked for; requested_input_tensor(idx)
            // reports it afterwards.
    297     requested_input_tensor_[idx] = true;
            return tensor;
    298   }
    299 
    300   // Returns true iff input_tensor(idx) was called by the shape function.
          // (input_tensor() sets this flag on every call, whether or not the tensor
          // is available.)
    301   bool requested_input_tensor(int idx) const {
    302     return requested_input_tensor_[idx];
    303   }
    304 
    305   // Returns true if MakeShapeFromInputTensor was called but the constant
    306   // input_tensor was not present.
          // NOTE(review): no MakeShapeFromInputTensor is declared in the visible
          // part of this class; MakeShapeFromShapeTensor is presumably what is
          // meant -- confirm.
    307   bool requested_input_tensor_as_partial_shape(int idx) const {
    308     return requested_input_tensor_as_partial_shape_[idx];
    309   }
    310 
    311   void set_input_tensors(const std::vector<const Tensor*>& input_tensors) {
            // Replaces the stored tensor pointers with a copy of <input_tensors>.
            // Only the pointers are copied, not the tensors they point to.
    312     input_tensors_ = input_tensors;
    313   }
    314 
          // Replaces the stored input_tensors_as_shapes_ with a copy of
          // <input_tensors_as_shapes>.
    315   void set_input_tensors_as_shapes(
    316       const std::vector<ShapeHandle>& input_tensors_as_shapes) {
    317     input_tensors_as_shapes_ = input_tensors_as_shapes;
    318   }
    319 
    320   void set_output(int idx, ShapeHandle shape) { outputs_[idx] = shape; }
          // ^ Stores <shape> as the shape of output <idx>; <idx> is not
          //   range-checked here.
          //
          // NOTE(review): presumably assigns <shapes> to the outputs registered
          // under <output_name> -- the definition is not in this header, confirm.
    321   Status set_output(StringPiece output_name,
    322                     const std::vector<ShapeHandle>& shapes);
    323 
          // Number of outputs, i.e. the size of outputs_.
    324   int num_outputs() const { return outputs_.size(); }
          // Returns the shape currently stored for output <idx>; not range-checked.
    325   ShapeHandle output(int idx) const { return outputs_[idx]; }
          // NOTE(review): presumably the by-name counterpart of output(int),
          // writing the shapes to <*output> -- confirm against the definition.
    326   Status output(StringPiece output_name,
    327                 std::vector<ShapeHandle>* output) const;
    328 
    329   AttrSlice attrs() const { return AttrSlice(*node_def_); }
          // ^ Wraps *node_def_ in an AttrSlice. node_def_ is dereferenced
          //   unconditionally, in line with the constructors' "not NULL"
          //   requirement on <node_def>.
    330 
          // NOTE(review): presumably the op name of the node being evaluated --
          // the definition is not in this header, confirm.
    331   string op() const;
    332 
    333   // idx can be negative for an offset from end of dimensions.
    334   // idx must be in the range [-1 * s.rank, s.rank).
    335   DimensionHandle Dim(ShapeHandle s, int64 idx) {
            // An unset (default-constructed) handle has no Shape behind it, so
            // reading s->rank_ would go through a null pointer. Treat it like a
            // shape of unknown rank, which is also what Rank() and RankKnown()
            // report for an unset handle.
    336     if (!s.IsSet() || s->rank_ == kUnknownRank) {
    337       return UnknownDim();
    338     }
            // Known rank: DimKnownRank() resolves negative indices from the end.
    339     return DimKnownRank(s, idx);
    340   }
    341   // As above, but asserts that the rank of the shape is known.
    342   static DimensionHandle DimKnownRank(ShapeHandle s, int64 idx) {
    343     CHECK_NE(s->rank_, kUnknownRank);
    344     const std::vector<DimensionHandle>& dims = s->dims_;
    345     // A negative idx counts back from the end of the dimension list.
    346     return (idx < 0) ? dims[dims.size() + idx] : dims[idx];
    348   }
    349 
    350   static int32 Rank(ShapeHandle s) {
    351     DCHECK(s.IsSet());
            // Reached with an unset handle only if the DCHECK above did not abort;
            // such a handle is reported as unknown rank.
    352     if (!s.IsSet()) return kUnknownRank;
            return s->rank_;
    353   }
    354   static bool RankKnown(ShapeHandle s) {
            // An unset handle never has a known rank; Rank() is only consulted for
            // handles that are set.
    355     if (!s.IsSet()) return false;
            return Rank(s) != kUnknownRank;
    356   }
    357   static inline int64 Value(DimensionOrConstant d) {
            // A set handle takes precedence over the raw constant.
    358     if (d.dim.IsSet()) return d.dim->value_;
            return d.val;
    359   }
    360   static inline bool ValueKnown(DimensionOrConstant d) {
            // Known means "anything other than the kUnknownDim sentinel".
    361     const int64 dim_value = Value(d);
            return dim_value != kUnknownDim;
    362   }
    363 
    364   // Fills the output proto with the shape defined by the handle.
    365   // "proto" is expected to be empty prior to the call.
    366   void ShapeHandleToProto(ShapeHandle handle, TensorShapeProto* proto);
    367 
    368   // Returns true if the rank and all dimensions of the Shape are known.
    369   bool FullyDefined(ShapeHandle s);
    370 
    371   // Returns the total number of elements, or an unknown dimension for an
    372   // incomplete shape.
          // The result is a DimensionHandle rather than a raw count; read it with
          // Value() / ValueKnown().
    373   DimensionHandle NumElements(ShapeHandle s);
    374 
          // Return a string form of a single shape or dimension. Unlike the
          // no-argument overload below, these two are not const-qualified.
    375   string DebugString(ShapeHandle s);
    376   string DebugString(DimensionHandle d);
    377 
    378   // Describes the whole context, for debugging purposes.
    379   string DebugString() const;
    380 
    381   // If <shape> has rank <rank>, or its rank is unknown, return OK and return
    382   // the shape with asserted rank in <*out>. Otherwise return an error.
    383   //
    384   // Note that <*out> may be set to <shape>.
    385   Status WithRank(ShapeHandle shape, int64 rank,
    386                   ShapeHandle* out) TF_MUST_USE_RESULT;
          // NOTE(review): presumably as WithRank, but accepting any rank that is
          // at least <rank> -- inferred from the name only; confirm against the
          // definition.
    387   Status WithRankAtLeast(ShapeHandle shape, int64 rank,
    388                          ShapeHandle* out) TF_MUST_USE_RESULT;
          // NOTE(review): presumably as WithRank, but accepting any rank that is
          // at most <rank> -- inferred from the name only; confirm against the
          // definition.
    389   Status WithRankAtMost(ShapeHandle shape, int64 rank,
    390                         ShapeHandle* out) TF_MUST_USE_RESULT;
    391 
    392   // If <dim> has value <value>, or its value is unknown, returns OK and returns
    393   // the dimension with asserted value in <*out>. Otherwise returns an error.
    394   //
    395   // Note that <*out> may be set to <dim>.
    396   Status WithValue(DimensionHandle dim, int64 value,
    397                    DimensionHandle* out) TF_MUST_USE_RESULT;
    398 
    399   // Merges <s0> and <s1> and returns the merged shape in <*out>. See
    400   // 'MergeInput' function for full details and examples.
          // MergeInput() itself is implemented on top of this overload.
    401   Status Merge(ShapeHandle s0, ShapeHandle s1,
    402                ShapeHandle* out) TF_MUST_USE_RESULT;
    403 
    404   // Asserts that <s>'s rank >= <prefix>'s rank, and the first
    405   // <prefix.rank> dimensions of <s> are compatible with the dimensions of
    406   // <prefix>.
    407   // Returns the merged results in <*s_out> and <*prefix_out>.
    408   Status MergePrefix(ShapeHandle s, ShapeHandle prefix, ShapeHandle* s_out,
    409                      ShapeHandle* prefix_out) TF_MUST_USE_RESULT;
    410 
    411   // Merges <d0> and <d1> and returns the merged dimension in <*out>. If <d0>
    412   // and <d1> have incompatible values, returns an error.
    413   //
    414   // Note that <*out> may be set to <d0> or <d1>.
    415   Status Merge(DimensionHandle d0, DimensionHandle d1,
    416                DimensionHandle* out) TF_MUST_USE_RESULT;
    417 
    418   // Returns in <*out> a sub-shape of <s> with dimensions [start:].
    419   // <start> can be negative to index from the end of the shape. If <start> >
    420   // rank of <s>, then an empty subshape is returned.
    421   Status Subshape(ShapeHandle s, int64 start,
    422                   ShapeHandle* out) TF_MUST_USE_RESULT;
    423 
    424   // Returns in <*out> a sub-shape of <s>, with dimensions [start:end].
    425   // <start> and <end> can be negative, to index from the end of the shape.
    426   // <start> and <end> are set to the rank of <s> if > rank of <s>.
    427   Status Subshape(ShapeHandle s, int64 start, int64 end,
    428                   ShapeHandle* out) TF_MUST_USE_RESULT;
    429 
    430   // Returns in <*out> the result of appending the dimensions of <s2> to those
    431   // of <s1>.
    432   Status Concatenate(ShapeHandle s1, ShapeHandle s2,
    433                      ShapeHandle* out) TF_MUST_USE_RESULT;
    434 
    435   // Returns in <*out> the shape that results from replacing
    436   // <s.dim[dim_index]> with <new_dim>.
    437   Status ReplaceDim(ShapeHandle s, int64 dim_index, DimensionHandle new_dim,
    438                     ShapeHandle* out) TF_MUST_USE_RESULT;
    439 
    440   // Returns a new shape with the given dims. The returned value is owned by
    441   // this context.
    442   ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);
          // As above, but every element may be given either as a DimensionHandle or
          // as an int64 constant (both convert implicitly to DimensionOrConstant).
    443   ShapeHandle MakeShape(std::initializer_list<DimensionOrConstant> dims);
    444 
    445   // Returns a new unknown shape.
    446   ShapeHandle UnknownShape();
    447 
    448   // Returns a shape with specified rank but unknown dims.
    449   ShapeHandle UnknownShapeOfRank(int64 rank);
    450 
    451   // Returns a new shape of zero dimensions.
    452   ShapeHandle Scalar();
    453 
    454   // Returns a new shape of one dimension.
    455   ShapeHandle Vector(DimensionOrConstant dim);
    456 
    457   // Returns a new shape of two dimensions.
    458   ShapeHandle Matrix(DimensionOrConstant dim1, DimensionOrConstant dim2);
    459 
    460   // Returns in <out> a new shape whose dimension sizes come from input tensor
    461   // <input_idx>. The tensor must be a 1-dimensional int32 or int64 tensor.  If
    462   // the input tensor is NULL, then an unknown shape is returned.
          // NOTE(review): the constructor comments say <input_tensors_as_shapes> is
          // consulted by this function when input_tensors[i] is nullptr, so a
          // partially known shape may presumably be returned in that case --
          // confirm against the definition.
    463   Status MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out);
    464 
    465   // Returns in <out> a new shape corresponding to <proto>.
    466   Status MakeShapeFromShapeProto(const TensorShapeProto& proto,
    467                                  ShapeHandle* out);
    468 
    469   // Returns in <out> a new shape corresponding to <partial_shape>.
    470   Status MakeShapeFromPartialTensorShape(
    471       const PartialTensorShape& partial_shape, ShapeHandle* out);
    472 
    473   // Returns in <out> a new shape corresponding to <shape>.
    474   Status MakeShapeFromTensorShape(const TensorShape& shape, ShapeHandle* out);
    475 
    476   // Returns a new dimension of the given size.  The returned value is owned by
    477   // this context.
          // Forwards to shape_manager_.MakeDim(): when <d> already carries a set
          // DimensionHandle that handle is returned unchanged, otherwise a new
          // Dimension is allocated for <d.val>.
    478   inline DimensionHandle MakeDim(DimensionOrConstant d) {
    479     return shape_manager_.MakeDim(d);
    480   }
    481 
          // Returns a newly allocated dimension whose value is kUnknownDim.
    482   inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); }
    483 
    484   // Returns in <val> a scalar value from an input tensor <t>.  The input tensor
    485   // must be a 1-dimensional int32 or int64 tensor.  Caller must ensure that the
    486   // input tensor is not NULL.
          // NOTE(review): "scalar" and "1-dimensional" look inconsistent here --
          // confirm the accepted tensor rank against the definition.
    487   Status GetScalarFromTensor(const Tensor* t, int64* val);
    488 
    489   // Returns a new dimension whose value is given by a scalar input tensor.
    490   // The input tensor must be in host memory, since it is dereferenced to get
    491   // the value.
    492   Status MakeDimForScalarInput(int idx, DimensionHandle* out);
    493 
    494   // Returns a new dimension whose value is given by a scalar input tensor.
    495   // This allows for a negative input dimension given the rank of a separate
    496   // tensor.  This rank can be negative if unknown.
    497   // The input tensor must be in host memory, since it is dereferenced to get
    498   // the value.
    499   Status MakeDimForScalarInputWithNegativeIndexing(int idx, int input_rank,
    500                                                    DimensionHandle* out);
    501 
    502   // Look up the attr for the NodeDef being evaluated with name attr_name and
    503   // set *value to its value.  If no attr with attr_name is found in def(), or
    504   // the attr does not have a matching type, a non-ok status will be returned.
          // Only declared here; the template definition is not in the visible part
          // of this header.
    505   template <class T>
    506   Status GetAttr(StringPiece attr_name, T* value) const;
    507 
    508   // Returns in <out> the result of dividing <dividend> by <divisor>.
    509   // Returns an error if <divisor> is not positive or if <evenly_divisible>
    510   // is set and <divisor> does not evenly divide <dividend>.
    511   Status Divide(DimensionHandle dividend, DimensionOrConstant divisor,
    512                 bool evenly_divisible, DimensionHandle* out);
    513 
    514   // Returns in <out> the sum of <first> and <second>.
    515   Status Add(DimensionHandle first, DimensionOrConstant second,
    516              DimensionHandle* out);
    517 
    518   // Returns in <out> the dimension that is <first> minus <second>.
    519   Status Subtract(DimensionHandle first, DimensionOrConstant second,
    520                   DimensionHandle* out);
    521 
    522   // Returns in <out> the product of <first> and <second>.
    523   Status Multiply(DimensionHandle first, DimensionOrConstant second,
    524                   DimensionHandle* out);
    525 
    526   // Returns in <out> the minimum of <first> and <second>. If either <first> or
    527   // <second> is zero the result is zero. Otherwise, if either <first> or
    528   // <second> is unknown the result is unknown.
    529   Status Min(DimensionHandle first, DimensionOrConstant second,
    530              DimensionHandle* out);
    531 
    532   // Returns in <out> the maximum of <first> and <second>. If either <first> or
    533   // <second> is unknown the result is unknown.
    534   Status Max(DimensionHandle first, DimensionOrConstant second,
    535              DimensionHandle* out);
    536 
    537   Status construction_status() const { return construction_status_; }
          // ^ Returns a copy of the stored construction_status_.
    538 
    539   // Methods to propagate shape and dtype on edges of handles. Handles are the
    540   // dtype DT_RESOURCE which can be used to access state stored in a
    541   // ResourceManager. When ops (such as variables) consume these handles to
    542   // produce tensors they might need to know side-information about the shapes
    543   // and dtypes of tensors which can be accessed via the handle. These methods
    544   // propagate that information. Output handle dtypes and shapes are ignored if
    545   // the output tensor is not of type DT_RESOURCE.
    546 
    547   // Merge the stored shapes and types corresponding to the input handle in
    548   // position idx with the specified shapes and types. This requires idx to be
    549   // in the [0, num_inputs) range.
    550   //
    551   // If the merge is successful and any of the new shapes differs from the old
    552   // one, or any of the old dtypes was DT_INVALID, store the new shapes and
    553   // return true.  Return false otherwise.
    554   //
    555   // See 'MergeInput' function for full details and examples.
    556   bool MergeInputHandleShapesAndTypes(
    557       int idx,
    558       const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
    559 
    560   // As MergeInputHandleShapesAndTypes, but for an output.
    561   bool MergeOutputHandleShapesAndTypes(
    562       int idx,
    563       const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
    564 
    565   // Relaxes the stored shapes and types corresponding to the input handle in
    566   // position idx with the specified shapes and types. This requires idx to be
    567   // in the [0, num_inputs) range.
    568   //
    569   // If the relax is successful and any of the new shapes differs from the old
    570   // one, or any of the old dtypes was DT_INVALID, store the new shapes and
    571   // return true.  Return false otherwise.
    572   //
    573   // See 'RelaxInput' function for full details and examples.
    574   bool RelaxInputHandleShapesAndMergeTypes(
    575       int idx,
    576       const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
    577 
    578   // As RelaxInputHandleShapesAndMergeTypes, but for an output.
    579   bool RelaxOutputHandleShapesAndMergeTypes(
    580       int idx,
    581       const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
    582 
    583   // Returns the output handle shapes and types, for the resource tensor output
    584   // at index <idx>. Returns NULL if the shape and types were never set.
    585   const std::vector<ShapeAndType>* output_handle_shapes_and_types(
              int idx) const {
            // Pure read of the stored pointer, hence const-qualified like the other
            // read-only accessors (input(), output(), num_inputs(), ...). The
            // result is null when nothing was stored for <idx>; <idx> itself is
            // not range-checked.
    586     return output_handle_shapes_and_types_[idx].get();
    587   }
    588 
    589   // Returns the input handle shapes and types, for the resource tensor input
    590   // at index <idx>. Returns NULL if the shape and types were not available.
    591   const std::vector<ShapeAndType>* input_handle_shapes_and_types(
              int idx) const {
            // Pure read of the stored pointer, hence const-qualified like the other
            // read-only accessors (input(), output(), num_inputs(), ...). The
            // result is null when nothing is stored for <idx>; <idx> itself is not
            // range-checked.
    592     return input_handle_shapes_and_types_[idx].get();
    593   }
    594 
    595   void set_output_handle_shapes_and_types(
    596       int idx, const std::vector<ShapeAndType>& shapes_and_types) {
            // Store a heap-allocated copy of the caller's vector in slot <idx>.
    597     auto* copied = new std::vector<ShapeAndType>(shapes_and_types);
    598     output_handle_shapes_and_types_[idx].reset(copied);
    599   }
    600 
    601   // Note that shape functions should usually call MakeShapeFromShapeTensor,
    602   // as it does more analysis to provide partial shapes.
    603   //
    604   // Returns in <out> a new shape whose dimension sizes come from tensor <t>.
    605   // The tensor must be a 1-dimensional int32 or int64 tensor.  If <t> is NULL,
    606   // then an unknown shape is returned.
          // NOTE(review): <tensor_shape> is not described above; presumably it is
          // the shape of <t> itself -- confirm against the definition.
    607   Status MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape,
    608                              ShapeHandle* out);
    609 
          // Returns the stored graph_def_version_.
    610   int graph_def_version() const { return graph_def_version_; }
    611 
    612   const std::vector<std::pair<ShapeHandle, ShapeHandle>>& MergedShapes() const {
            // Read-only reference to merged_shapes_; the list is emptied by
            // ForgetMerges().
    613     return merged_shapes_;
    614   }
          // Read-only reference to merged_dims_; the list is emptied by
          // ForgetMerges().
    615   const std::vector<std::pair<DimensionHandle, DimensionHandle>>& MergedDims()
    616       const {
    617     return merged_dims_;
    618   }
    619 
    620  private:
    621   // Creates and stores shapes for use in InferenceContext.
    622   class ShapeManager {
    623    public:
    624     ShapeManager();
    625     ~ShapeManager();
    626 
    627     // Creates a shape from <dims>. This class keeps ownership of the object
    628     // behind the returned handle.
    629     ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);
    630 
    631     // Returns a new unknown shape.
    632     ShapeHandle UnknownShape();
    633 
    634     // Hands back <d.dim> unchanged when it is set. Otherwise allocates a new
    635     // Dimension holding <d.val>; that object is owned by this class.
    636     inline DimensionHandle MakeDim(DimensionOrConstant d) {
    637       if (d.dim.IsSet()) return d.dim;
    638       Dimension* const created = new Dimension(d.val);
    639       all_dims_.push_back(created);
    640       return created;
    643     }
    644 
    645    private:
    646     std::vector<Shape*> all_shapes_;    // values are owned.
    647     std::vector<Dimension*> all_dims_;  // values are owned.
    648   };
    649 
    650   friend class ::tensorflow::grappler::GraphProperties;
    651 
    652   // Friend for user-defined function shape inference purposes.
    653   friend class ::tensorflow::ShapeRefiner;
    654 
    655   friend class ShapeInferenceTest;      // For testing Relax functions.
    656   friend class ShapeInferenceTestutil;  // For testing shapes.
    657 
    658   // Shared initialization across the two constructors.  Remove
    659   // once we get rid of one of them.
          // NOTE(review): three constructors are declared in the public section, so
          // the "two constructors" wording above looks stale -- confirm.
    660   void PreInputInit(const OpDef& op_def,
    661                     const std::vector<const Tensor*>& input_tensors,
    662                     const std::vector<ShapeHandle>& input_tensors_as_shapes);
          // Takes <input_handle_data> by value; its elements are std::unique_ptr,
          // so the argument has to be a temporary or be std::move()d in.
    663   void PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
    664                          input_handle_data);
    665 
          // NOTE(review): presumably resolves <d> to a DimensionHandle (the handle
          // it carries, or a new dimension for its constant) -- the definition is
          // not in this header, confirm.
    666   DimensionHandle GetDimension(const DimensionOrConstant& d);
    667 
          // Stores the result of UnknownShape() in <*out>. Always returns OK.
    668   Status ReturnUnknownShape(ShapeHandle* out) {
    669     *out = UnknownShape();
    670     return Status::OK();
    671   }
    672   Status ReturnCreatedShape(const std::vector<DimensionHandle>& dims,
    673                             ShapeHandle* out) {
            // Stores the result of MakeShape(dims) in <*out>. Always returns OK.
    674     *out = MakeShape(dims);
    675     return Status::OK();
    676   }
    677 
    678   // Adds additional context to the given status and returns the result.
    679   Status AttachContext(const Status& status);
    680 
    681   // Relaxes an existing value <d_old> with a new value <d_new> and stores the
    682   // relaxed dimension in <*out>. The return type is void, so no error is
    683   // reported; see the definition for how incompatible values are handled.
    684   //
    685   // Note that <*out> may be set to <d_old> or <d_new>.
    686   void Relax(DimensionHandle d_old, DimensionHandle d_new,
    687              DimensionHandle* out);
    688   // Relaxes an existing shape <s_old> with a new shape <s_new> and stores the
    689   // relaxed shape in <*out>. See 'RelaxInput' function for full details and
    690   // examples.
    691   void Relax(ShapeHandle s_old, ShapeHandle s_new, ShapeHandle* out);
    692 
    693   // Used to implement MergeInputHandleShapesAndTypes and
    694   // MergeOutputHandleShapesAndTypes. Marked TF_MUST_USE_RESULT.
    695   bool MergeHandleShapesAndTypes(
    696       const std::vector<ShapeAndType>& shapes_and_types,
    697       std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
    698   // Used to implement RelaxInputHandleShapesAndMergeTypes and
    699   // RelaxOutputHandleShapesAndMergeTypes. Marked TF_MUST_USE_RESULT.
    700   bool RelaxHandleShapesAndMergeTypes(
    701       const std::vector<ShapeAndType>& shapes_and_types,
    702       std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
    703 
    704   // Forgets every recorded merge by clearing merged_shapes_ and merged_dims_.
    705   void ForgetMerges() {
    706     merged_shapes_.clear();
    707     merged_dims_.clear();
    708   }
    709 
    710   ShapeManager shape_manager_;
    711 
    712   // inputs_, outputs_, and input_tensors_as_shapes_ refer to values from
    713   // `shape_manager_`.
    714   std::vector<ShapeHandle> inputs_;
    715   std::vector<const Tensor*> input_tensors_;
    716   std::vector<bool> requested_input_tensor_;
    717   std::vector<ShapeHandle> outputs_;
    718   // Can have fewer elements than inputs_.
    719   std::vector<ShapeHandle> input_tensors_as_shapes_;
    720   std::vector<bool> requested_input_tensor_as_partial_shape_;
    721 
    722   // input_handle_shapes_and_types_[i] is the list of shape/type pairs available
    723   // through the resource handle passed along input i of the node.
    724   //
    725   // Elements may be null.
    726   std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
    727       input_handle_shapes_and_types_;
    728 
    729   // output_handle_shapes_and_types_[i] is the list of shape/type pairs
    730   // available through the resource handle passed along output i of the node.
    731   //
    732   // Elements may be null.
    733   std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
    734       output_handle_shapes_and_types_;
    735 
    736   const int graph_def_version_;
    737   const NodeDef* node_def_;  // Dereferenced by GetAttr() without a null check.
    738   NameRangeMap input_name_map_;
    739   NameRangeMap output_name_map_;
    740 
    741   // An error set during construction. TODO(cwhipkey): remove when test
    742   // constructor is removed.
    743   Status construction_status_;
    744 
    745   // Pairs of shape or dim handles that are equivalent, i.e. that represent the
    746   // same underlying shape or dimension. Note that for each pair at least one of
    747   // the handles must contain an unknown shape, since we don't keep track of
    748   // known shapes or dims here.
    749   std::vector<std::pair<ShapeHandle, ShapeHandle>> merged_shapes_;
    750   std::vector<std::pair<DimensionHandle, DimensionHandle>> merged_dims_;
    751 
    752   TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext);
    753 };
    754 
    755 // -----------------------------------------------------------------------------
    756 // Template and inline method implementations, please ignore
    757 
        // A default-constructed Dimension holds InferenceContext::kUnknownDim.
    758 inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {}
        // Stores <value>. DCHECKs that it is either non-negative or
        // InferenceContext::kUnknownDim.
    759 inline Dimension::Dimension(int64 value) : value_(value) {
    760   DCHECK(value >= 0 || value == InferenceContext::kUnknownDim)
    761       << "Dimension must be non-negative or equal to "
    762          "InferenceContext::kUnknownDim but got "
    763       << value;
    764 }
    765 
        // A default-constructed Shape has rank InferenceContext::kUnknownRank.
    766 inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {}
        // Stores <dims> in dims_ and sets the rank to the number of entries in <dims>.
    767 inline Shape::Shape(const std::vector<DimensionHandle>& dims)
    768     : rank_(dims.size()), dims_(dims) {}
    769 
        // Wraps an existing dimension handle; only the `dim` member is initialized
        // here. DCHECKs that <dim> is set.
    770 inline DimensionOrConstant::DimensionOrConstant(DimensionHandle dim)
    771     : dim(dim) {
    772   DCHECK(dim.IsSet()) << "Internal error: Got nullptr for Dimension.";
    773 }
    774 
        // Wraps a constant size; only the `val` member is initialized here. DCHECKs
        // that <val> is either non-negative or InferenceContext::kUnknownDim.
    775 inline DimensionOrConstant::DimensionOrConstant(int64 val) : val(val) {
    776   DCHECK(val >= 0 || val == InferenceContext::kUnknownDim)
    777       << "Dimension must be non-negative or equal to "
    778          "InferenceContext::kUnknownDim but got "
    779       << val;
    780 }
    781 
        // Forwards to GetNodeAttr(*node_def_, attr_name, value) and returns its
        // status. node_def_ is dereferenced without a null check.
    782 template <class T>
    783 Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
    784   return GetNodeAttr(*node_def_, attr_name, value);
    785 }
    786 
    787 }  // namespace shape_inference
    788 }  // namespace tensorflow
    789 
    790 #endif  // TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
    791