/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
#define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_

#include <vector>

#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

class ShapeRefiner;
class ShapeRefinerTest;

namespace grappler {
class GraphProperties;
class SymbolicShapeManager;
}  // namespace grappler

namespace shape_inference {

struct DimensionOrConstant;
class InferenceContext;

// Dimension values are accessed through InferenceContext.
class Dimension {
 private:
  Dimension();
  Dimension(int64 value);
  ~Dimension() {}

  const int64 value_;

  friend class InferenceContext;
  friend class ShapeManager;
  TF_DISALLOW_COPY_AND_ASSIGN(Dimension);
};

class DimensionHandle {
 public:
  DimensionHandle() {}
  bool SameHandle(DimensionHandle d) const { return ptr_ == d.ptr_; }
  std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }

 private:
  DimensionHandle(const Dimension* dim) { ptr_ = dim; }

  const Dimension* operator->() const { return ptr_; }
  bool IsSet() const { return ptr_ != nullptr; }

  const Dimension* ptr_ = nullptr;

  friend struct DimensionOrConstant;
  friend class InferenceContext;
  friend class ShapeInferenceTest;
  friend class ShapeInferenceTestutil;
  friend class ::tensorflow::ShapeRefinerTest;
  friend class ShapeManager;
  friend class ::tensorflow::grappler::GraphProperties;
  friend class ::tensorflow::grappler::SymbolicShapeManager;

  // Intentionally copyable.
};

// Shape rank and dimensions are accessed through InferenceContext.
class Shape {
 private:
  Shape();
  Shape(const std::vector<DimensionHandle>& dims);
  ~Shape() {}

  const int32 rank_;
  const std::vector<DimensionHandle> dims_;

  friend class InferenceContext;
  friend class ShapeManager;
  friend class ::tensorflow::grappler::SymbolicShapeManager;

  TF_DISALLOW_COPY_AND_ASSIGN(Shape);
};

class ShapeHandle {
 public:
  ShapeHandle() {}
  bool SameHandle(ShapeHandle s) const { return ptr_ == s.ptr_; }
  std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }

 private:
  ShapeHandle(const Shape* shape) { ptr_ = shape; }
  const Shape* operator->() const { return ptr_; }
  bool IsSet() const { return ptr_ != nullptr; }

  const Shape* ptr_ = nullptr;

  friend class InferenceContext;
  friend class ShapeInferenceTest;
  friend class ShapeInferenceTestutil;
  friend class ::tensorflow::ShapeRefinerTest;
  friend class ShapeManager;
  friend class ::tensorflow::grappler::SymbolicShapeManager;

  // Intentionally copyable.
};

// Struct used to allow functions to take DimensionHandle or a dimension value.
// Not meant to be constructed directly.
struct DimensionOrConstant {
 public:
  // Intentionally not explicit.
  DimensionOrConstant(DimensionHandle dim);

  // val must be non-negative or InferenceContext::kUnknownDim.
  DimensionOrConstant(int64 val);

  // dim takes precedence. If dim.IsSet(), val is ignored.
  DimensionHandle dim;
  int64 val;

 private:
  DimensionOrConstant();
};

struct ShapeAndType {
  ShapeAndType() {}
  ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {}

  ShapeHandle shape;
  DataType dtype = DT_INVALID;
};

// Shape inference functions registered on ops in REGISTER_OP implement
// their shape functions in terms of this InferenceContext.  An InferenceContext
// is created by the framework and passed to a shape inference function.  The
// shape inference function calls functions on the context, and should call
// set_output() to set the shape on all outputs.
//
// To infer shapes for user-defined functions see ShapeRefiner.
//
// All Shape* and Dimension* returned by functions of InferenceContext are owned
// by the InferenceContext.
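//
// As an illustrative sketch (not part of this header; "MyPassThrough" is a
// hypothetical op), a trivial shape function might be registered as:
//
//   REGISTER_OP("MyPassThrough")
//       .Input("in: float")
//       .Output("out: float")
//       .SetShapeFn([](shape_inference::InferenceContext* c) {
//         c->set_output(0, c->input(0));
//         return Status::OK();
//       });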
class InferenceContext {
 public:
  static constexpr int64 kUnknownDim = -1;
  static constexpr int32 kUnknownRank = -1;

  // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
  //
  // Elements of <input_tensors_as_shapes> are used when a shape function
  // makes a call to MakeShapeFromShapeTensor; in particular, when
  // input_tensors[i] is nullptr but the shape represented by it is partially
  // known from analysis of the graph.
  // <input_tensors_as_shapes> can have fewer elements than <input_shapes>.
  // Values of <input_tensors_as_shapes> do not need to outlive the context.
  //
  // REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext.
  InferenceContext(int graph_def_version, const NodeDef* node_def,
                   const OpDef& op_def,
                   const std::vector<ShapeHandle>& input_shapes,
                   const std::vector<const Tensor*>& input_tensors,
                   const std::vector<ShapeHandle>& input_tensors_as_shapes,
                   std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
                       input_handle_shapes_and_types);

  // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
  //
  // Elements of <input_tensors_as_shapes> are used when a shape
  // function makes a call to MakeShapeFromShapeTensor; in particular, when
  // input_tensors[i] is nullptr but the shape represented by it is
  // partially known from analysis of the graph. <input_tensors_as_shapes>
  // can have fewer elements than <input_shapes>. Values of
  // <input_tensors_as_shapes> do not need to outlive the context.
  //
  // REQUIRES: <node_def> is not NULL, and must outlive the
  // InferenceContext.
  InferenceContext(
      int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
      const std::vector<TensorShapeProto>& input_shapes,
      const std::vector<const Tensor*>& input_tensors,
      const std::vector<TensorShapeProto>& input_tensors_as_shapes,
      const std::vector<
          std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>&
          input_handle_shapes_and_types);

  // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
  //
  // Elements of <input_tensors_as_shapes> are used when a shape
  // function makes a call to MakeShapeFromShapeTensor; in particular, when
  // input_tensors[i] is nullptr but the shape represented by it is
  // partially known from analysis of the graph. <input_tensors_as_shapes>
  // can have fewer elements than <input_shapes>. Values of
  // <input_tensors_as_shapes> do not need to outlive the context.
  //
  // REQUIRES: <node_def> is not NULL, and must outlive the
  // InferenceContext.
  InferenceContext(
      int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
      const std::vector<PartialTensorShape>& input_shapes,
      const std::vector<const Tensor*>& input_tensors,
      const std::vector<PartialTensorShape>& input_tensors_as_shapes,
      const std::vector<std::unique_ptr<
          std::vector<std::pair<PartialTensorShape, DataType>>>>&
          input_handle_shapes_and_types);

  ~InferenceContext();

  // Runs the shape inference function 'fn' with 'this' as the argument, and
  // returns the status of the inference.
  //
  // On error, additional context is provided in the error message.
  Status Run(
      const std::function<Status(shape_inference::InferenceContext* c)>& fn);

  // Merge the stored shape of the input in position idx with <shape> according
  // to the following rules:
  //
  // - If the ShapeHandles are the same or <shape> is unknown, there will be no
  //   change. Otherwise if the stored shape is unknown, the new shape will be
  //   <shape>.
  // - If both shapes are known, then they must have the same rank.
  // - For any one dimension, if the values for that dimension in both shapes
  //   are known, then the values must match.
  // - If one shape has equal or more information than the other shape in every
  //   dimension, the new shape will become the shape with more information.
  // - Example: merging [2,?] and [?,2] results in [2,2]
  // - Example: [2,2] cannot be merged with [1,2]
  //
  // This requires idx to be in the [0, num_inputs) range. If the merge is
  // successful, return true. Return false otherwise.
  bool MergeInput(int idx, ShapeHandle shape) {
    ShapeHandle new_shape;
    if (!Merge(inputs_[idx], shape, &new_shape).ok()) return false;
    inputs_[idx] = new_shape;
    return true;
  }

  // Relax the stored shape of the input in position idx with <shape> according
  // to the following rules:
  //
  // - If the ShapeHandles are the same then the stored shape will be returned.
  // - If either of the ShapeHandles is unknown, then a new UnknownShape will
  //   be returned. A new shape must be returned because we cannot claim that
  //   the resulting shape is necessarily the same as either of the input
  //   shapes.
  // - If the shapes both have known ranks but their ranks are different, a new
  //   UnknownShape will be returned.
  // - For any one dimension, if the value for that dimension in either of the
  //   shapes is unknown, a new shape will be returned with a new UnknownDim in
  //   that dimension.
  // - For any one dimension, if the values for that dimension in both shapes
  //   are known but do not match, a new shape will be returned with a new
  //   UnknownDim in that dimension.
  // - If both shapes have the same known rank and match in every dimension,
  //   the stored shape will be returned.
  // - Example: relaxing [2,?] and [?,2] results in [?,?]
  // - Example: relaxing [2,2] and [3,2] results in [?,2]
  // - Example: relaxing [2,2] with [1,2,3] results in ?
  //
  // This requires idx to be in the [0, num_inputs) range. If the relax is
  // successful and the new shape differs from the old one, store the new
  // shape and return true. Return false otherwise.
  bool RelaxInput(int idx, ShapeHandle shape) {
    ShapeHandle new_shape;
    Relax(inputs_[idx], shape, &new_shape);
    if (inputs_[idx].SameHandle(new_shape)) {
      return false;
    }
    inputs_[idx] = new_shape;
    return true;
  }


  void SetInput(int idx, ShapeHandle shape) { inputs_[idx] = shape; }

  ShapeHandle input(int64 idx) const { return inputs_[idx]; }
  Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const;
  int num_inputs() const { return inputs_.size(); }

  // Returns the input tensor at index <idx>, or nullptr if the input tensor is
  // not available at the time of shape inference.
  const Tensor* input_tensor(int idx) {
    // Mark that this idx was requested.
    requested_input_tensor_[idx] = true;
    return input_tensors_[idx];
  }

  // Returns true iff input_tensor(idx) was called by the shape function.
  bool requested_input_tensor(int idx) const {
    return requested_input_tensor_[idx];
  }

  // Returns true if MakeShapeFromShapeTensor was called but the constant
  // input tensor was not present.
  bool requested_input_tensor_as_partial_shape(int idx) const {
    return requested_input_tensor_as_partial_shape_[idx];
  }

  void set_input_tensors(const std::vector<const Tensor*>& input_tensors) {
    input_tensors_ = input_tensors;
  }

  void set_input_tensors_as_shapes(
      const std::vector<ShapeHandle>& input_tensors_as_shapes) {
    input_tensors_as_shapes_ = input_tensors_as_shapes;
  }

  const std::vector<ShapeHandle>& input_tensors_as_shapes() const {
    return input_tensors_as_shapes_;
  }

  ShapeHandle output(int64 idx) const { return outputs_.at(idx); }
  void set_output(int idx, ShapeHandle shape) { outputs_.at(idx) = shape; }
  Status set_output(StringPiece output_name,
                    const std::vector<ShapeHandle>& shapes);

  int num_outputs() const { return outputs_.size(); }
  ShapeHandle output(int idx) const { return outputs_.at(idx); }
  Status output(StringPiece output_name,
                std::vector<ShapeHandle>* output) const;

  AttrSlice attrs() const { return AttrSlice(*node_def_); }

  string op() const;

  // idx can be negative for an offset from end of dimensions.
  // idx must be in the range [-1 * s.rank, s.rank).
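  //
  // For example (sketch), given a ShapeHandle s known to be of the form
  // [batch, height, width, channels]:
  //   Dim(s, 0);    // batch
  //   Dim(s, -1);   // channels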
  DimensionHandle Dim(ShapeHandle s, int64 idx) {
    if (s->rank_ == kUnknownRank) {
      return UnknownDim();
    }
    return DimKnownRank(s, idx);
  }
  // As above, but asserts that the rank of the shape is known.
  static DimensionHandle DimKnownRank(ShapeHandle s, int64 idx) {
    CHECK_NE(s->rank_, kUnknownRank);
    if (idx < 0) {
      return s->dims_[s->dims_.size() + idx];
    }
    return s->dims_[idx];
  }

  static int32 Rank(ShapeHandle s) {
    DCHECK(s.IsSet());
    return s.IsSet() ? s->rank_ : kUnknownRank;
  }
  static bool RankKnown(ShapeHandle s) {
    return (s.IsSet() && (Rank(s) != kUnknownRank));
  }
  static inline int64 Value(DimensionOrConstant d) {
    return d.dim.IsSet() ? d.dim->value_ : d.val;
  }
  static inline bool ValueKnown(DimensionOrConstant d) {
    return Value(d) != kUnknownDim;
  }

  // Fills the output proto with the shape defined by the handle.
  // "proto" is expected to be empty prior to the call.
  void ShapeHandleToProto(ShapeHandle handle, TensorShapeProto* proto);

  // Returns true if the rank and all dimensions of the Shape are known.
  bool FullyDefined(ShapeHandle s);

  // Returns the total number of elements, or an unknown dimension for an
  // incomplete shape.
  DimensionHandle NumElements(ShapeHandle s);

  string DebugString(ShapeHandle s);
  string DebugString(DimensionHandle d);
  string DebugString(const ShapeAndType& shape_and_type);
  string DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types);

  // Describes the whole context, for debugging purposes.
  string DebugString() const;

  // If <shape> has rank <rank>, or its rank is unknown, return OK and return
  // the shape with asserted rank in <*out>. Otherwise return an error.
  //
  // Note that <*out> may be set to <shape>.
  Status WithRank(ShapeHandle shape, int64 rank,
                  ShapeHandle* out) TF_MUST_USE_RESULT;
  Status WithRankAtLeast(ShapeHandle shape, int64 rank,
                         ShapeHandle* out) TF_MUST_USE_RESULT;
  Status WithRankAtMost(ShapeHandle shape, int64 rank,
                        ShapeHandle* out) TF_MUST_USE_RESULT;
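  //
  // As an illustrative sketch (a hypothetical "MyMatMul" op, ignoring
  // transposition attrs), a shape function might combine WithRank, Dim, Merge
  // and Matrix as follows:
  //
  //   Status MyMatMulShape(InferenceContext* c) {
  //     ShapeHandle a, b;
  //     TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a));
  //     TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b));
  //     DimensionHandle inner;
  //     TF_RETURN_IF_ERROR(c->Merge(c->Dim(a, 1), c->Dim(b, 0), &inner));
  //     c->set_output(0, c->Matrix(c->Dim(a, 0), c->Dim(b, 1)));
  //     return Status::OK();
  //   }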

  // If <dim> has value <value>, or its value is unknown, returns OK and returns
  // the dimension with asserted value in <*out>. Otherwise returns an error.
  //
  // Note that <*out> may be set to <dim>.
  Status WithValue(DimensionHandle dim, int64 value,
                   DimensionHandle* out) TF_MUST_USE_RESULT;

  // Merges <s0> and <s1> and returns the merged shape in <*out>. See
  // 'MergeInput' function for full details and examples.
  Status Merge(ShapeHandle s0, ShapeHandle s1,
               ShapeHandle* out) TF_MUST_USE_RESULT;

  // Asserts that <s>'s rank >= <prefix>'s rank, and the first
  // <prefix.rank> dimensions of <s> are compatible with the dimensions of
  // <prefix>.
  // Returns the merged results in <*s_out> and <*prefix_out>.
  Status MergePrefix(ShapeHandle s, ShapeHandle prefix, ShapeHandle* s_out,
                     ShapeHandle* prefix_out) TF_MUST_USE_RESULT;

  // Merges <d0> and <d1> and returns the merged dimension in <*out>. If <d0>
  // and <d1> have incompatible values, returns an error.
  //
  // Note that <*out> may be set to <d0> or <d1>.
  Status Merge(DimensionHandle d0, DimensionHandle d1,
               DimensionHandle* out) TF_MUST_USE_RESULT;

  // Returns in <*out> a sub-shape of <s> with dimensions [start:].
  // <start> can be negative to index from the end of the shape. If <start> >
  // rank of <s>, then an empty subshape is returned.
  Status Subshape(ShapeHandle s, int64 start,
                  ShapeHandle* out) TF_MUST_USE_RESULT;

  // Returns in <*out> a sub-shape of <s>, with dimensions [start:end].
  // <start> and <end> can be negative, to index from the end of the shape.
  // <start> and <end> are clamped to the rank of <s> if they exceed it.
  Status Subshape(ShapeHandle s, int64 start, int64 end,
                  ShapeHandle* out) TF_MUST_USE_RESULT;

  // Returns in <*out> a sub-shape of <s>, with dimensions [start:end:stride].
  // <start> and <end> can be negative, to index from the end of the shape.
  // <start> and <end> are clamped to the rank of <s> if they exceed it.
  // <stride> can be negative, to traverse <s> in reverse order.
  Status Subshape(ShapeHandle s, int64 start, int64 end, int64 stride,
                  ShapeHandle* out) TF_MUST_USE_RESULT;
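  //
  // For example (sketch; status handling omitted), for a ShapeHandle s of
  // rank 4:
  //   Subshape(s, 1, &out);      // dims [1:], drops the first dimension
  //   Subshape(s, 0, -1, &out);  // dims [:-1], drops the last dimension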

  // Returns in <*out> the result of appending the dimensions of <s2> to those
  // of <s1>.
  Status Concatenate(ShapeHandle s1, ShapeHandle s2,
                     ShapeHandle* out) TF_MUST_USE_RESULT;

  // Returns in <out> the shape from replacing <s.dim[dim_index]> with
  // <new_dim>.
  Status ReplaceDim(ShapeHandle s, int64 dim_index, DimensionHandle new_dim,
                    ShapeHandle* out) TF_MUST_USE_RESULT;

  // Returns a new shape with the given dims. The returned value is owned by
  // this context.
  ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);
  ShapeHandle MakeShape(std::initializer_list<DimensionOrConstant> dims);

  // Returns a new unknown shape.
  ShapeHandle UnknownShape();

  // Returns a shape with specified rank but unknown dims.
  ShapeHandle UnknownShapeOfRank(int64 rank);

  // Returns a new shape of zero dimensions.
  ShapeHandle Scalar();

  // Returns a new shape of one dimension.
  ShapeHandle Vector(DimensionOrConstant dim);

  // Returns a new shape of two dimensions.
  ShapeHandle Matrix(DimensionOrConstant dim1, DimensionOrConstant dim2);

  // Returns in <out> a new shape whose dimension sizes come from input tensor
  // <input_idx>. The tensor must be a 1-dimensional int32 or int64 tensor.  If
  // the input tensor is NULL, then an unknown shape is returned.
  Status MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out);
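  //
  // As an illustrative sketch (for a hypothetical op whose input 0 is a shape
  // vector, in the spirit of ops like Fill):
  //
  //   ShapeHandle out;
  //   TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
  //   c->set_output(0, out);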

  // Like MakeShapeFromShapeTensor above, but treats scalar values as unknown
  // shapes.  **NOTE** If the scalar is statically known, its value
  // must be -1 or an error is returned.
  Status MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx,
                                                           ShapeHandle* out);

  // Returns in <out> a new shape corresponding to <proto>.
  Status MakeShapeFromShapeProto(const TensorShapeProto& proto,
                                 ShapeHandle* out);

  // Returns in <out> a new shape corresponding to <partial_shape>.
  Status MakeShapeFromPartialTensorShape(
      const PartialTensorShape& partial_shape, ShapeHandle* out);

  // Returns in <out> a new shape corresponding to <shape>.
  Status MakeShapeFromTensorShape(const TensorShape& shape, ShapeHandle* out);

  // Returns a new dimension of the given size.  The returned value is owned by
  // this context.
  inline DimensionHandle MakeDim(DimensionOrConstant d) {
    return shape_manager_.MakeDim(d);
  }

  inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); }

  // Returns in <val> a scalar value from an input tensor <t>.  The input tensor
  // must be a 1-dimensional int32 or int64 tensor.  Caller must ensure that the
  // input tensor is not NULL.
  Status GetScalarFromTensor(const Tensor* t, int64* val);

  // Returns a new dimension whose value is given by a scalar input tensor.
  // The input tensor must be in host memory, since it is dereferenced to get
  // the value.
  Status MakeDimForScalarInput(int idx, DimensionHandle* out);

  // Returns a new dimension whose value is given by a scalar input tensor.
  // This allows for a negative input dimension given the rank of a separate
  // tensor.  This rank can be negative if unknown.
  // The input tensor must be in host memory, since it is dereferenced to get
  // the value.
  Status MakeDimForScalarInputWithNegativeIndexing(int idx, int input_rank,
                                                   DimensionHandle* out);

  // Look up the attr for the NodeDef being evaluated with name attr_name and
  // set *value to its value.  If no attr with attr_name is found in def(), or
  // the attr does not have a matching type, a non-ok status will be returned.
  template <class T>
  Status GetAttr(StringPiece attr_name, T* value) const;
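  //
  // For example (sketch; "N" is a hypothetical attr on the op being inferred):
  //   int32 n;
  //   TF_RETURN_IF_ERROR(c->GetAttr("N", &n));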

  // Returns in <out> the result of dividing <dividend> by <divisor>.
  // Returns an error if <divisor> is not positive or if <evenly_divisible>
  // is true and <divisor> does not evenly divide <dividend>.
  Status Divide(DimensionHandle dividend, DimensionOrConstant divisor,
                bool evenly_divisible, DimensionHandle* out);

  // Returns in <out> the sum of <first> and <second>.
  Status Add(DimensionHandle first, DimensionOrConstant second,
             DimensionHandle* out);

  // Returns in <out> the dimension that is <first> minus <second>.
  Status Subtract(DimensionHandle first, DimensionOrConstant second,
                  DimensionHandle* out);

  // Returns in <out> the product of <first> and <second>.
  Status Multiply(DimensionHandle first, DimensionOrConstant second,
                  DimensionHandle* out);

  // Returns in <out> the minimum of <first> and <second>. If either <first> or
  // <second> is zero, the result is zero. Otherwise, if either <first> or
  // <second> is unknown, the result is unknown.
  Status Min(DimensionHandle first, DimensionOrConstant second,
             DimensionHandle* out);

  // Returns in <out> the maximum of <first> and <second>. If either <first> or
  // <second> is unknown, the result is unknown.
  Status Max(DimensionHandle first, DimensionOrConstant second,
             DimensionHandle* out);
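  //
  // As an illustrative sketch (for a hypothetical op that splits a ShapeHandle
  // <input> into two equal halves along dimension 0):
  //
  //   DimensionHandle half;
  //   TF_RETURN_IF_ERROR(
  //       c->Divide(c->Dim(input, 0), 2, /*evenly_divisible=*/true, &half));
  //   ShapeHandle output;
  //   TF_RETURN_IF_ERROR(c->ReplaceDim(input, 0, half, &output));
  //   c->set_output(0, output);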

  Status construction_status() const { return construction_status_; }

  // Methods to propagate shape and dtype on edges of handles. Handles are
  // tensors of dtype DT_RESOURCE, which can be used to access state stored in
  // a ResourceManager. When ops (such as variables) consume these handles to
  // produce tensors, they might need to know side-information about the shapes
  // and dtypes of tensors which can be accessed via the handle. These methods
  // propagate that information. Output handle dtypes and shapes are ignored if
  // the output tensor is not of type DT_RESOURCE.

  // Merge the stored shapes and types corresponding to the input handle in
  // position idx with the specified shapes and types. This requires idx to be
  // in the [0, num_inputs) range.
  //
  // If the merge is successful and any of the new shapes differs from the old
  // one, or any of the old dtypes was DT_INVALID, store the new shapes and
  // return true.  Return false otherwise.
  //
  // See 'MergeInput' function for full details and examples.
  bool MergeInputHandleShapesAndTypes(
      int idx,
      const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;

  // As MergeInputHandleShapesAndTypes, but for an output.
  bool MergeOutputHandleShapesAndTypes(
      int idx,
      const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;

  // Relaxes the stored shapes and types corresponding to the input handle in
  // position idx with the specified shapes and types. This requires idx to be
  // in the [0, num_inputs) range.
  //
  // If the relax is successful (sizes are the same, old dtypes match new ones
  // or are DT_INVALID), then store the relaxed shapes and return true.
  // Return false otherwise.
  //
  // See 'RelaxInput' function for full details and examples.
  bool RelaxInputHandleShapesAndMergeTypes(
      int idx,
      const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;

  // As RelaxInputHandleShapesAndMergeTypes, but for an output.
  bool RelaxOutputHandleShapesAndMergeTypes(
      int idx,
      const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;

  void set_input_handle_shapes_and_types(
      int idx, const std::vector<ShapeAndType>& shapes_and_types) {
    input_handle_shapes_and_types_[idx].reset(
        new std::vector<ShapeAndType>(shapes_and_types));
  }

  // Returns the output handle shapes and types, for the resource tensor output
  // at index <idx>. Returns NULL if the shapes and types were never set.
  const std::vector<ShapeAndType>* output_handle_shapes_and_types(int idx) {
    return output_handle_shapes_and_types_[idx].get();
  }

  // Returns the input handle shapes and types, for the resource tensor input
  // at index <idx>. Returns NULL if the shapes and types were not available.
  const std::vector<ShapeAndType>* input_handle_shapes_and_types(int idx) {
    return input_handle_shapes_and_types_[idx].get();
  }
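
  // As an illustrative sketch (a shape function for a hypothetical op that
  // reads a tensor through a resource handle, in the spirit of
  // ReadVariableOp):
  //
  //   Status ReadFromResourceShape(InferenceContext* c) {
  //     const auto* handle_data = c->input_handle_shapes_and_types(0);
  //     if (handle_data != nullptr && !handle_data->empty()) {
  //       c->set_output(0, (*handle_data)[0].shape);
  //     } else {
  //       c->set_output(0, c->UnknownShape());
  //     }
  //     return Status::OK();
  //   }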

  void set_output_handle_shapes_and_types(
      int idx, const std::vector<ShapeAndType>& shapes_and_types) {
    output_handle_shapes_and_types_[idx].reset(
        new std::vector<ShapeAndType>(shapes_and_types));
  }

  // Note that shape functions should usually call MakeShapeFromShapeTensor,
  // as it does more analysis to provide partial shapes.
  //
  // Returns in <out> a new shape whose dimension sizes come from tensor <t>.
  // The tensor must be a 1-dimensional int32 or int64 tensor.  If <t> is NULL,
  // then an unknown shape is returned.
  Status MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape,
                             ShapeHandle* out);

  int graph_def_version() const { return graph_def_version_; }

  const std::vector<std::pair<ShapeHandle, ShapeHandle>>& MergedShapes() const {
    return merged_shapes_;
  }
  const std::vector<std::pair<DimensionHandle, DimensionHandle>>& MergedDims()
      const {
    return merged_dims_;
  }

  // Adds new outputs; useful when mutating the graph.
  Status ExpandOutputs(int new_output_size);

 private:
  // Creates and stores shapes for use in InferenceContext.
  class ShapeManager {
   public:
    ShapeManager();
    ~ShapeManager();

    // Returns a new shape with the given dims. The returned value is owned by
    // this class.
    ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);

    // Returns a new unknown shape.
    ShapeHandle UnknownShape();

    // Returns a new dimension of the given size.  The returned value
    // is owned by this class.
    inline DimensionHandle MakeDim(DimensionOrConstant d) {
      if (d.dim.IsSet()) {
        return d.dim;
      } else {
        all_dims_.push_back(new Dimension(d.val));
        return all_dims_.back();
      }
    }

   private:
    std::vector<Shape*> all_shapes_;    // values are owned.
    std::vector<Dimension*> all_dims_;  // values are owned.
  };

  friend class ::tensorflow::grappler::GraphProperties;

  // Friend for user-defined function shape inference purposes.
  friend class ::tensorflow::ShapeRefiner;

  friend class ShapeInferenceTest;      // For testing Relax functions.
  friend class ShapeInferenceTestutil;  // For testing shapes.

  // Shared initialization across the constructors.  Remove once the
  // redundant constructors are gone.
  void PreInputInit(const OpDef& op_def,
                    const std::vector<const Tensor*>& input_tensors,
                    const std::vector<ShapeHandle>& input_tensors_as_shapes);
  void PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
                         input_handle_data);

  DimensionHandle GetDimension(const DimensionOrConstant& d);

  Status ReturnUnknownShape(ShapeHandle* out) {
    *out = UnknownShape();
    return Status::OK();
  }
  Status ReturnCreatedShape(const std::vector<DimensionHandle>& dims,
                            ShapeHandle* out) {
    *out = MakeShape(dims);
    return Status::OK();
  }

  // Adds additional context to the given status.
  Status AttachContext(const Status& status);

  // Relaxes an existing value <d_old> with a new value <d_new> and returns the
  // relaxed dimension in <*out>. If <d_old> and <d_new> have incompatible
  // values, the relaxed dimension is an unknown dimension.
  //
  // Note that <*out> may be set to <d_old> or <d_new>.
  void Relax(DimensionHandle d_old, DimensionHandle d_new,
             DimensionHandle* out);
  // Relaxes an existing shape <s_old> with a new shape <s_new> and returns the
  // relaxed shape in <*out>. See 'RelaxInput' function for full details and
  // examples.
  void Relax(ShapeHandle s_old, ShapeHandle s_new, ShapeHandle* out);

  // Used to implement MergeInputHandleShapesAndTypes and
  // MergeOutputHandleShapesAndTypes.
  bool MergeHandleShapesAndTypes(
      const std::vector<ShapeAndType>& shapes_and_types,
      std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
  // Used to implement RelaxInputHandleShapesAndMergeTypes and
  // RelaxOutputHandleShapesAndMergeTypes.
  bool RelaxHandleShapesAndMergeTypes(
      const std::vector<ShapeAndType>& shapes_and_types,
      std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;

  // Forget all the previous merged shapes and dims.
  void ForgetMerges() {
    merged_shapes_.clear();
    merged_dims_.clear();
  }

  // Helper method for MakeShapeFromTensor and MakeShapeFromShapeTensor.
  Status InternalMakeShapeFromTensor(
      bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
      ShapeHandle tensor_shape, ShapeHandle* out);

  ShapeManager shape_manager_;

  // inputs_, outputs_, and input_tensors_as_shapes_ refer to values from
  // `shape_manager_`.
  std::vector<ShapeHandle> inputs_;
  std::vector<const Tensor*> input_tensors_;
  std::vector<bool> requested_input_tensor_;
  std::vector<ShapeHandle> outputs_;
  // Can have fewer elements than inputs_.
  std::vector<ShapeHandle> input_tensors_as_shapes_;
  std::vector<bool> requested_input_tensor_as_partial_shape_;

  // input_handle_shapes_and_types_[i] is the list of shape/type pairs available
  // through the resource handle passed along input i of the node.
  //
  // Values may be NULL.
  std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
      input_handle_shapes_and_types_;

  // output_handle_shapes_and_types_[i] is the list of shape/type pairs
  // available through the resource handle passed along output i of the node.
  //
  // Values may be NULL.
  std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
      output_handle_shapes_and_types_;

  const int graph_def_version_;
  const NodeDef* node_def_;
  NameRangeMap input_name_map_;
  NameRangeMap output_name_map_;

  // An error set during construction. TODO(cwhipkey): remove when test
  // constructor is removed.
  Status construction_status_;

  // Pairs of shape or dimension handles that are equivalent, i.e. that
  // represent the same underlying shape or dimension. Note that for each pair
  // at least one of the handles must contain an unknown shape, since we don't
  // keep track of known shapes or dims here.
  std::vector<std::pair<ShapeHandle, ShapeHandle>> merged_shapes_;
  std::vector<std::pair<DimensionHandle, DimensionHandle>> merged_dims_;

  TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext);
};

// -----------------------------------------------------------------------------
// Template and inline method implementations, please ignore

inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {}
inline Dimension::Dimension(int64 value) : value_(value) {
  DCHECK(value >= 0 || value == InferenceContext::kUnknownDim)
      << "Dimension must be non-negative or equal to "
         "InferenceContext::kUnknownDim but got "
      << value;
}

inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {}
inline Shape::Shape(const std::vector<DimensionHandle>& dims)
    : rank_(dims.size()), dims_(dims) {}

inline DimensionOrConstant::DimensionOrConstant(DimensionHandle dim)
    : dim(dim) {
  DCHECK(dim.IsSet()) << "Internal error: Got nullptr for Dimension.";
}

inline DimensionOrConstant::DimensionOrConstant(int64 val) : val(val) {
  DCHECK(val >= 0 || val == InferenceContext::kUnknownDim)
      << "Dimension must be non-negative or equal to "
         "InferenceContext::kUnknownDim but got "
      << val;
}

template <class T>
Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
  return GetNodeAttr(*node_def_, attr_name, value);
}

}  // namespace shape_inference
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_