1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_ 16 #define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_ 17 18 #include <vector> 19 20 #include "tensorflow/core/framework/node_def_util.h" 21 #include "tensorflow/core/framework/tensor.h" 22 #include "tensorflow/core/lib/core/errors.h" 23 #include "tensorflow/core/lib/core/status.h" 24 #include "tensorflow/core/lib/gtl/inlined_vector.h" 25 #include "tensorflow/core/platform/macros.h" 26 27 namespace tensorflow { 28 29 class ShapeRefiner; 30 class ShapeRefinerTest; 31 32 namespace grappler { 33 class GraphProperties; 34 class SymbolicShapeManager; 35 } // namespace grappler 36 37 namespace shape_inference { 38 39 struct DimensionOrConstant; 40 class InferenceContext; 41 42 // Dimension values are accessed through InferenceContext. 43 class Dimension { 44 private: 45 Dimension(); 46 Dimension(int64 value); 47 ~Dimension() {} 48 49 const int64 value_; 50 51 friend class InferenceContext; 52 friend class ShapeManager; 53 TF_DISALLOW_COPY_AND_ASSIGN(Dimension); 54 }; 55 56 class DimensionHandle { 57 public: 58 DimensionHandle() {} 59 bool SameHandle(DimensionHandle d) const { return ptr_ == d.ptr_; } 60 std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); } 61 62 private: 63 DimensionHandle(const Dimension* dim) { ptr_ = dim; } 64 65 const Dimension* operator->() const { return ptr_; } 66 bool IsSet() const { return ptr_ != nullptr; } 67 68 const Dimension* ptr_ = nullptr; 69 70 friend struct DimensionOrConstant; 71 friend class InferenceContext; 72 friend class ShapeInferenceTest; 73 friend class ShapeInferenceTestutil; 74 friend class ::tensorflow::ShapeRefinerTest; 75 friend class ShapeManager; 76 friend class ::tensorflow::grappler::GraphProperties; 77 friend class ::tensorflow::grappler::SymbolicShapeManager; 78 79 // Intentionally copyable. 80 }; 81 82 // Shape rank and dimensions are accessed through InferenceContext. 83 class Shape { 84 private: 85 Shape(); 86 Shape(const std::vector<DimensionHandle>& dims); 87 ~Shape() {} 88 89 const int32 rank_; 90 const std::vector<DimensionHandle> dims_; 91 92 friend class InferenceContext; 93 friend class ShapeManager; 94 friend class ::tensorflow::grappler::SymbolicShapeManager; 95 96 TF_DISALLOW_COPY_AND_ASSIGN(Shape); 97 }; 98 99 class ShapeHandle { 100 public: 101 ShapeHandle() {} 102 bool SameHandle(ShapeHandle s) const { return ptr_ == s.ptr_; } 103 std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); } 104 105 private: 106 ShapeHandle(const Shape* shape) { ptr_ = shape; } 107 const Shape* operator->() const { return ptr_; } 108 bool IsSet() const { return ptr_ != nullptr; } 109 110 const Shape* ptr_ = nullptr; 111 112 friend class InferenceContext; 113 friend class ShapeInferenceTest; 114 friend class ShapeInferenceTestutil; 115 friend class ::tensorflow::ShapeRefinerTest; 116 friend class ShapeManager; 117 friend class ::tensorflow::grappler::SymbolicShapeManager; 118 119 // Intentionally copyable. 120 }; 121 122 // Struct used to allow functions to take DimensionHandle or a dimension value. 123 // Not meant to be constructed directly. 124 struct DimensionOrConstant { 125 public: 126 // Intentionally not explicit. 127 DimensionOrConstant(DimensionHandle dim); 128 129 // val must be non-negative or InferenceContext::kUnknownDim. 130 DimensionOrConstant(int64 val); 131 132 // dim takes precedence. If dim != nullptr, val is ignored. 133 DimensionHandle dim; 134 int64 val; 135 136 private: 137 DimensionOrConstant(); 138 }; 139 140 struct ShapeAndType { 141 ShapeAndType() {} 142 ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {} 143 144 ShapeHandle shape; 145 DataType dtype = DT_INVALID; 146 }; 147 148 // Shape inference functions registered on ops in REGISTER_OP implement 149 // their shape functions in terms of this InferenceContext. An InferenceContext 150 // is created by the framework and passed to a shape inference function. The 151 // shape inference function calls functions on the context, and should call 152 // set_output() to set the shape on all outputs. 153 // 154 // To infer shapes for user-defined functions see ShapeRefiner. 155 // 156 // All Shape* and Dimension* returned by functions of InferenceContext are owned 157 // by the InferenceContext. 158 class InferenceContext { 159 public: 160 static constexpr int64 kUnknownDim = -1; 161 static constexpr int32 kUnknownRank = -1; 162 163 // <input_tensors> is NULL-padded to be the same size as <input_shapes>. 164 // 165 // Elements of <input_tensors_as_shapes> are used for when a shape function 166 // makes a call to MakeShapeFromShapeTensor; in particular, when the 167 // input_tensors[i] is nullptr but the shape represented by it is partially 168 // known from analysis of the graph. 169 // <input_tensors_as_shapes> can have fewer elements than <input_shapes>. 170 // Values of <input_tensors_as_shapes> do not need to outlive the context. 171 // 172 // REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext. 173 InferenceContext(int graph_def_version, const NodeDef* node_def, 174 const OpDef& op_def, 175 const std::vector<ShapeHandle>& input_shapes, 176 const std::vector<const Tensor*>& input_tensors, 177 const std::vector<ShapeHandle>& input_tensors_as_shapes, 178 std::vector<std::unique_ptr<std::vector<ShapeAndType>>> 179 input_handle_shapes_and_types); 180 181 // <input_tensors> is NULL-padded to be the same size as <input_shapes>. 182 // 183 // Elements of <input_tensors_as_shapes> are used for when a shape 184 // function makes a call to MakeShapeFromShapeTensor; in particular, when 185 // the input_tensors[i] is nullptr but the shape represented by it is 186 // partially known from analysis of the graph. <input_tensors_as_shapes> 187 // can have fewer elements than <input_shapes>. Values of 188 // <input_tensors_as_shapes> do not need to outlive the context. 189 // 190 // REQUIRES: <node_def> is not NULL, and must outlive the 191 // InferenceContext. 192 InferenceContext( 193 int graph_def_version, const NodeDef* node_def, const OpDef& op_def, 194 const std::vector<TensorShapeProto>& input_shapes, 195 const std::vector<const Tensor*>& input_tensors, 196 const std::vector<TensorShapeProto>& input_tensors_as_shapes, 197 const std::vector< 198 std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>& 199 input_handle_shapes_and_types); 200 201 // <input_tensors> is NULL-padded to be the same size as <input_shapes>. 202 // 203 // Elements of <input_tensors_as_shapes> are used for when a shape 204 // function makes a call to MakeShapeFromShapeTensor; in particular, when 205 // the input_tensors[i] is nullptr but the shape represented by it is 206 // partially known from analysis of the graph. <input_tensors_as_shapes> 207 // can have fewer elements than <input_shapes>. Values of 208 // <input_tensors_as_shapes> do not need to outlive the context. 209 // 210 // REQUIRES: <node_def> is not NULL, and must outlive the 211 // InferenceContext. 212 InferenceContext( 213 int graph_def_version, const NodeDef* node_def, const OpDef& op_def, 214 const std::vector<PartialTensorShape>& input_shapes, 215 const std::vector<const Tensor*>& input_tensors, 216 const std::vector<PartialTensorShape>& input_tensors_as_shapes, 217 const std::vector<std::unique_ptr< 218 std::vector<std::pair<PartialTensorShape, DataType>>>>& 219 input_handle_shapes_and_types); 220 221 ~InferenceContext(); 222 223 // Runs the shape inference function 'fn' with 'this' as the 224 // argument, returns the status of the inference. 225 // 226 // On error, additional context is provided in the error message. 227 Status Run( 228 const std::function<Status(shape_inference::InferenceContext* c)>& fn); 229 230 // Merge the stored shape of the input in position idx with <shape> according 231 // to the following rules: 232 // 233 // - If the ShapeHandles are the same or <shape> is unknown, there will be no 234 // change. Otherwise if the stored shape is unknown, the new shape will be 235 // <shape>. 236 // - If both shapes are known, then they must have the same rank. 237 // - For any one dimension, if the values for that dimension in both shapes 238 // are known, then the values must match. 239 // - If one shape has equal or more information than the other shape in every 240 // dimension, the new shape will become the shape with more information. 241 // - Example: merging [2,?] and [?,2] results in [2,2] 242 // - Example: [2,2] cannot be merged with [1,2] 243 // 244 // This requires idx to be in the [0, num_inputs) range. If the merge is 245 // successful, return true. Return false otherwise. 246 bool MergeInput(int idx, ShapeHandle shape) { 247 ShapeHandle new_shape; 248 if (!Merge(inputs_[idx], shape, &new_shape).ok()) return false; 249 inputs_[idx] = new_shape; 250 return true; 251 } 252 253 // Relax the stored shape of the input in position idx with <shape> according 254 // to the following rules: 255 // 256 // - If the ShapeHandles are the same then the stored shape will be returned. 257 // - If either of the ShapeHandles are unknown, then a new UnknownShape will 258 // be returned. A new shape must be returned because we cannot claim that 259 // the resulting shape is necessarily the same as either of the input 260 // shapes. 261 // - If the shapes both have known ranks but their ranks are different, a new 262 // UnknownShape will be returned. 263 // - For any one dimension, if the value for that dimension in either of the 264 // shapes is unknown, a new shape will be returned with a new UnknownDim in 265 // that dimension. 266 // - For any one dimension, if the values for that dimension in both shapes 267 // are known but do not match, a new shape will be returned with a new 268 // UnknownDim in that dimension. 269 // - If both shapes have the same known rank and match in every dimension, 270 // the stored shape will be returned. 271 // - Example: relaxing [2,?] and [?,2] results in [?,?] 272 // - Example: relaxing [2,2] and [3,2] results in [?,2] 273 // - Example: relaxing [2,2] with [1,2,3] results in ? 274 // 275 // This requires idx to be in the [0, num_inputs) range. If the relax is 276 // successful and the new shape differs from the old one, store the new 277 // shape and return true. Return false otherwise. 278 bool RelaxInput(int idx, ShapeHandle shape) { 279 ShapeHandle new_shape; 280 Relax(inputs_[idx], shape, &new_shape); 281 if (inputs_[idx].SameHandle(new_shape)) { 282 return false; 283 } 284 inputs_[idx] = new_shape; 285 return true; 286 } 287 288 void SetInput(int idx, ShapeHandle shape) { inputs_[idx] = shape; } 289 290 ShapeHandle input(int64 idx) const { return inputs_[idx]; } 291 Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const; 292 int num_inputs() const { return inputs_.size(); } 293 294 // Returns the input tensor at index <idx>, or nullptr if the input tensor is 295 // not available at the time of shape inference. 296 const Tensor* input_tensor(int idx) { 297 // Mark that this idx was requested. 298 requested_input_tensor_[idx] = true; 299 return input_tensors_[idx]; 300 } 301 302 // Returns true iff input_tensor(idx) was called by the shape function. 303 bool requested_input_tensor(int idx) const { 304 return requested_input_tensor_[idx]; 305 } 306 307 // Returns true if MakeShapeFromInputTensor was called but the constant 308 // input_tensor was not present. 309 bool requested_input_tensor_as_partial_shape(int idx) const { 310 return requested_input_tensor_as_partial_shape_[idx]; 311 } 312 313 void set_input_tensors(const std::vector<const Tensor*>& input_tensors) { 314 input_tensors_ = input_tensors; 315 } 316 317 void set_input_tensors_as_shapes( 318 const std::vector<ShapeHandle>& input_tensors_as_shapes) { 319 input_tensors_as_shapes_ = input_tensors_as_shapes; 320 } 321 322 const std::vector<ShapeHandle>& input_tensors_as_shapes() const { 323 return input_tensors_as_shapes_; 324 } 325 326 ShapeHandle output(int64 idx) const { return outputs_.at(idx); } 327 void set_output(int idx, ShapeHandle shape) { outputs_.at(idx) = shape; } 328 Status set_output(StringPiece output_name, 329 const std::vector<ShapeHandle>& shapes); 330 331 int num_outputs() const { return outputs_.size(); } 332 ShapeHandle output(int idx) const { return outputs_.at(idx); } 333 Status output(StringPiece output_name, 334 std::vector<ShapeHandle>* output) const; 335 336 AttrSlice attrs() const { return AttrSlice(*node_def_); } 337 338 string op() const; 339 340 // idx can be negative for an offset from end of dimensions. 341 // idx must be in the range [-1 * s.rank, s.rank). 342 DimensionHandle Dim(ShapeHandle s, int64 idx) { 343 if (s->rank_ == kUnknownRank) { 344 return UnknownDim(); 345 } 346 return DimKnownRank(s, idx); 347 } 348 // As above, but asserts that the rank of the shape is known. 349 static DimensionHandle DimKnownRank(ShapeHandle s, int64 idx) { 350 CHECK_NE(s->rank_, kUnknownRank); 351 if (idx < 0) { 352 return s->dims_[s->dims_.size() + idx]; 353 } 354 return s->dims_[idx]; 355 } 356 357 static int32 Rank(ShapeHandle s) { 358 DCHECK(s.IsSet()); 359 return s.IsSet() ? s->rank_ : kUnknownRank; 360 } 361 static bool RankKnown(ShapeHandle s) { 362 return (s.IsSet() && (Rank(s) != kUnknownRank)); 363 } 364 static inline int64 Value(DimensionOrConstant d) { 365 return d.dim.IsSet() ? d.dim->value_ : d.val; 366 } 367 static inline bool ValueKnown(DimensionOrConstant d) { 368 return Value(d) != kUnknownDim; 369 } 370 371 // Fills the output proto with the shape defined by the handle. 372 // "proto" is expected to be empty prior to the call. 373 void ShapeHandleToProto(ShapeHandle handle, TensorShapeProto* proto); 374 375 // Returns true if the rank and all dimensions of the Shape are known. 376 bool FullyDefined(ShapeHandle s); 377 378 // Returns the total number of elements, or an unknown dimension for an 379 // incomplete shape. 380 DimensionHandle NumElements(ShapeHandle s); 381 382 string DebugString(ShapeHandle s); 383 string DebugString(DimensionHandle d); 384 string DebugString(const ShapeAndType& shape_and_type); 385 string DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types); 386 387 // Describes the whole context, for debugging purposes. 388 string DebugString() const; 389 390 // If <shape> has rank <rank>, or its rank is unknown, return OK and return 391 // the shape with asserted rank in <*out>. Otherwise return an error. 392 // 393 // Note that <*out> may be set to <shape>. 394 Status WithRank(ShapeHandle shape, int64 rank, 395 ShapeHandle* out) TF_MUST_USE_RESULT; 396 Status WithRankAtLeast(ShapeHandle shape, int64 rank, 397 ShapeHandle* out) TF_MUST_USE_RESULT; 398 Status WithRankAtMost(ShapeHandle shape, int64 rank, 399 ShapeHandle* out) TF_MUST_USE_RESULT; 400 401 // If <dim> has value <value>, or its value is unknown, returns OK and returns 402 // the dimension with asserted value in <*out>. Otherwise returns an error. 403 // 404 // Note that <*out> may be set to <dim>. 405 Status WithValue(DimensionHandle dim, int64 value, 406 DimensionHandle* out) TF_MUST_USE_RESULT; 407 408 // Merges <s0> and <s1> and returns the merged shape in <*out>. See 409 // 'MergeInput' function for full details and examples. 410 Status Merge(ShapeHandle s0, ShapeHandle s1, 411 ShapeHandle* out) TF_MUST_USE_RESULT; 412 413 // Asserts that <s>'s rank >= <prefix>'s rank, and the first 414 // <prefix.rank> dimensions of <s> are compatible with the dimensions of 415 // <prefix>. 416 // Returns the merged results in <*s_out> and <*prefix_out>. 417 Status MergePrefix(ShapeHandle s, ShapeHandle prefix, ShapeHandle* s_out, 418 ShapeHandle* prefix_out) TF_MUST_USE_RESULT; 419 420 // Merges <d0> and <d1> and returns the merged dimension in <*out>. If <d0> 421 // and <d1> have incompatible values, returns an error. 422 // 423 // Note that <*out> may be set to <d0> or <d1>. 424 Status Merge(DimensionHandle d0, DimensionHandle d1, 425 DimensionHandle* out) TF_MUST_USE_RESULT; 426 427 // Returns in <*out> a sub-shape of <s> with dimensions [start:]. 428 // <start> can be negative to index from the end of the shape. If <start> > 429 // rank of <s>, then an empty subshape is returned. 430 Status Subshape(ShapeHandle s, int64 start, 431 ShapeHandle* out) TF_MUST_USE_RESULT; 432 433 // Returns in <*out> a sub-shape of <s>, with dimensions [start:end]. 434 // <start> and <end> can be negative, to index from the end of the shape. 435 // <start> and <end> are set to the rank of <s> if > rank of <s>. 436 Status Subshape(ShapeHandle s, int64 start, int64 end, 437 ShapeHandle* out) TF_MUST_USE_RESULT; 438 439 // Returns in <*out> a sub-shape of <s>, with dimensions [start:end:stride]. 440 // <start> and <end> can be negative, to index from the end of the shape. 441 // <start> and <end> are set to the rank of <s> if > rank of <s>. 442 // <stride> can be negative, to reverse the <s>. 443 Status Subshape(ShapeHandle s, int64 start, int64 end, int64 stride, 444 ShapeHandle* out) TF_MUST_USE_RESULT; 445 446 // Returns in <*out> the result of appending the dimensions of <s2> to those 447 // of <s1>. 448 Status Concatenate(ShapeHandle s1, ShapeHandle s2, 449 ShapeHandle* out) TF_MUST_USE_RESULT; 450 451 // Returns in <out> the shape from replacing <s.dim[dim_index]> with 452 // <new_dim>. 453 Status ReplaceDim(ShapeHandle s, int64 dim_index, DimensionHandle new_dim, 454 ShapeHandle* out) TF_MUST_USE_RESULT; 455 456 // Returns a new shape with the given dims. The returned value is owned by 457 // this context. 458 ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims); 459 ShapeHandle MakeShape(std::initializer_list<DimensionOrConstant> dims); 460 461 // Returns a new unknown shape. 462 ShapeHandle UnknownShape(); 463 464 // Returns a shape with specified rank but unknown dims. 465 ShapeHandle UnknownShapeOfRank(int64 rank); 466 467 // Returns a new shape of zero dimensions. 468 ShapeHandle Scalar(); 469 470 // Returns a new shape of one dimension. 471 ShapeHandle Vector(DimensionOrConstant dim); 472 473 // Returns a new shape of two dimensions. 474 ShapeHandle Matrix(DimensionOrConstant dim1, DimensionOrConstant dim2); 475 476 // Returns in <out> a new shape whose dimension sizes come from input tensor 477 // <input_idx>. The tensor must be a 1-dimensional int32 or int64 tensor. If 478 // the input tensor is NULL, then an unknown shape is returned. 479 Status MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out); 480 481 // Like the function above, but treats scalar values as unknown 482 // shapes. **NOTE** If the scalar is statically known, its value 483 // must be -1 or an error is returned. 484 Status MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx, 485 ShapeHandle* out); 486 487 // Returns in <out> a new shape corresponding to <proto>. 488 Status MakeShapeFromShapeProto(const TensorShapeProto& proto, 489 ShapeHandle* out); 490 491 // Returns in <out> a new shape corresponding to <partial_shape>. 492 Status MakeShapeFromPartialTensorShape( 493 const PartialTensorShape& partial_shape, ShapeHandle* out); 494 495 // Returns in <out> a new shape corresponding to <shape>. 496 Status MakeShapeFromTensorShape(const TensorShape& shape, ShapeHandle* out); 497 498 // Returns a new dimension of the given size. The returned value is owned by 499 // this context. 500 inline DimensionHandle MakeDim(DimensionOrConstant d) { 501 return shape_manager_.MakeDim(d); 502 } 503 504 inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); } 505 506 // Returns in <val> a scalar value from an input tensor <t>. The input tensor 507 // must be a 1-dimensional int32 or int64 tensor. Caller must ensure that the 508 // input tensor is not NULL. 509 Status GetScalarFromTensor(const Tensor* t, int64* val); 510 511 // Returns a new dimension whose value is given by a scalar input tensor. 512 // The input tensor must be in host memory, since it is dereferenced to get 513 // the value. 514 Status MakeDimForScalarInput(int idx, DimensionHandle* out); 515 516 // Returns a new dimension whose value is given by a scalar input tensor. 517 // This allows for a negative input dimension given the rank of a separate 518 // tensor. This rank can be negative if unknown. 519 // The input tensor must be in host memory, since it is dereferenced to get 520 // the value. 521 Status MakeDimForScalarInputWithNegativeIndexing(int idx, int input_rank, 522 DimensionHandle* out); 523 524 // Look up the attr for the NodeDef being evaluated with name attr_name and 525 // set *value to its value. If no attr with attr_name is found in def(), or 526 // the attr does not have a matching type, a non-ok status will be returned. 527 template <class T> 528 Status GetAttr(StringPiece attr_name, T* value) const; 529 530 // Returns in <out> the result of dividing <dividend> by <divisor>. 531 // Returns an error if <divisor> is not positive or if <evenly_divisible> 532 // and <divisor> does not evenly divide <dividend>. 533 Status Divide(DimensionHandle dividend, DimensionOrConstant divisor, 534 bool evenly_divisible, DimensionHandle* out); 535 536 // Returns in <out> the sum of <first> and <second>. 537 Status Add(DimensionHandle first, DimensionOrConstant second, 538 DimensionHandle* out); 539 540 // Returns in <out> the dimension that is <first> minus <second>. 541 Status Subtract(DimensionHandle first, DimensionOrConstant second, 542 DimensionHandle* out); 543 544 // Returns in <out> the product of <first> and <second>. 545 Status Multiply(DimensionHandle first, DimensionOrConstant second, 546 DimensionHandle* out); 547 548 // Returns in <out> the minimum of <first> and <second>. If either <first> or 549 // <second> is zero the results is zero. Otherwise, if either <first> or 550 // <second> is unknown the results is unknown. 551 Status Min(DimensionHandle first, DimensionOrConstant second, 552 DimensionHandle* out); 553 554 // Returns in <out> the maximum of <first> and <second>. If either <first> or 555 // <second> is unknown the results is unknown. 556 Status Max(DimensionHandle first, DimensionOrConstant second, 557 DimensionHandle* out); 558 559 Status construction_status() const { return construction_status_; } 560 561 // Methods to propagate shape and dtype on edges of handles. Handles are the 562 // dtype DT_RESOURCE which can be used to access state stored in a 563 // ResourceManager. When ops (such as variables) consume these handles to 564 // produce tensors they might need to know side-information about the shapes 565 // and dtypes of tensors which can be accessed via the handle. These methods 566 // propagate that information. Output handle dtypes and shapes are ignored if 567 // the output tensor is not of type DT_RESOURCE. 568 569 // Merge the stored shapes and types corresponding to the input handle in 570 // position idx with the specified shapes and types. This requires idx to be 571 // in the [0, num_inputs) range. 572 // 573 // If the merge is successful and any of the new shapes differs from the old 574 // one, or any of the old dtypes was DT_INVALID, store the new shapes and 575 // return true. Return false otherwise. 576 // 577 // See 'MergeInput' function for full details and examples. 578 bool MergeInputHandleShapesAndTypes( 579 int idx, 580 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT; 581 582 // As MergeInputHandleShapesAndTypes, but for an output. 583 bool MergeOutputHandleShapesAndTypes( 584 int idx, 585 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT; 586 587 // Relaxes the stored shapes and types corresponding to the input handle in 588 // position idx with the specified shapes and types. This requires idx to be 589 // in the [0, num_inputs) range. 590 // 591 // If the relax is successful (sizes are the same, old dtypes match new ones 592 // or are DT_INVALID), then store the relaxed shapes and return true. 593 // Return false otherwise. 594 // 595 // See 'RelaxInput' function for full details and examples. 596 bool RelaxInputHandleShapesAndMergeTypes( 597 int idx, 598 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT; 599 600 // As RelaxInputHandleShapesAndTypes, but for an output. 601 bool RelaxOutputHandleShapesAndMergeTypes( 602 int idx, 603 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT; 604 605 void set_input_handle_shapes_and_types( 606 int idx, const std::vector<ShapeAndType>& shapes_and_types) { 607 input_handle_shapes_and_types_[idx].reset( 608 new std::vector<ShapeAndType>(shapes_and_types)); 609 } 610 611 // Returns the output handle shapes and types, for the resource tensor output 612 // at index <idx>. Returns NULL if the shape and types were never set. 613 const std::vector<ShapeAndType>* output_handle_shapes_and_types(int idx) { 614 return output_handle_shapes_and_types_[idx].get(); 615 } 616 617 // Returns the inputs handle shapes and types, for the resource tensor output 618 // at index <idx>. Returns NULL if the shape and types were not available. 619 const std::vector<ShapeAndType>* input_handle_shapes_and_types(int idx) { 620 return input_handle_shapes_and_types_[idx].get(); 621 } 622 623 void set_output_handle_shapes_and_types( 624 int idx, const std::vector<ShapeAndType>& shapes_and_types) { 625 output_handle_shapes_and_types_[idx].reset( 626 new std::vector<ShapeAndType>(shapes_and_types)); 627 } 628 629 // Note that shape functions should usually call MakeShapeFromShapeTensor, 630 // as it does more analysis to provide partial shapes. 631 // 632 // Returns in <out> a new shape whose dimension sizes come from tensor <t>. 633 // The tensor must be a 1-dimensional int32 or int64 tensor. If <t> is NULL, 634 // then an unknown shape is returned. 635 Status MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape, 636 ShapeHandle* out); 637 638 int graph_def_version() const { return graph_def_version_; } 639 640 const std::vector<std::pair<ShapeHandle, ShapeHandle>>& MergedShapes() const { 641 return merged_shapes_; 642 } 643 const std::vector<std::pair<DimensionHandle, DimensionHandle>>& MergedDims() 644 const { 645 return merged_dims_; 646 } 647 648 // Adds new outputs; useful when mutating the graph. 649 Status ExpandOutputs(int new_output_size); 650 651 private: 652 // Creates and stores shapes for use in InferenceContext. 653 class ShapeManager { 654 public: 655 ShapeManager(); 656 ~ShapeManager(); 657 658 // Returns a new shape with the given dims. The returned value is owned by 659 // this class. 660 ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims); 661 662 // Returns a new unknown shape. 663 ShapeHandle UnknownShape(); 664 665 // Returns a new dimension of the given size. The returned value 666 // is owned by this class. 667 inline DimensionHandle MakeDim(DimensionOrConstant d) { 668 if (d.dim.IsSet()) { 669 return d.dim; 670 } else { 671 all_dims_.push_back(new Dimension(d.val)); 672 return all_dims_.back(); 673 } 674 } 675 676 private: 677 std::vector<Shape*> all_shapes_; // values are owned. 678 std::vector<Dimension*> all_dims_; // values are owned. 679 }; 680 681 friend class ::tensorflow::grappler::GraphProperties; 682 683 // Friend for user-defined function shape inference purposes. 684 friend class ::tensorflow::ShapeRefiner; 685 686 friend class ShapeInferenceTest; // For testing Relax functions. 687 friend class ShapeInferenceTestutil; // For testing shapes. 688 689 // Shared initialization across the two constructors. Remove 690 // once we get rid of one of them. 691 void PreInputInit(const OpDef& op_def, 692 const std::vector<const Tensor*>& input_tensors, 693 const std::vector<ShapeHandle>& input_tensors_as_shapes); 694 void PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>> 695 input_handle_data); 696 697 DimensionHandle GetDimension(const DimensionOrConstant& d); 698 699 Status ReturnUnknownShape(ShapeHandle* out) { 700 *out = UnknownShape(); 701 return Status::OK(); 702 } 703 Status ReturnCreatedShape(const std::vector<DimensionHandle>& dims, 704 ShapeHandle* out) { 705 *out = MakeShape(dims); 706 return Status::OK(); 707 } 708 709 // Adds additional context to the given status. 710 Status AttachContext(const Status& status); 711 712 // Relaxes an existing value <d_old> with a new value <d_new> and returns the 713 // relaxed dimension in <*out>. If <d_old> and <d_new> have incompatible 714 // values, returns an error. 715 // 716 // Note that <*out> may be set to <d_old> or <d_new>. 717 void Relax(DimensionHandle d_old, DimensionHandle d_new, 718 DimensionHandle* out); 719 // Relaxes an existing shape <s_old> with a new shape <s_new> and returns the 720 // relaxed shape in <*out>. See 'RelaxInput' function for full details and 721 // examples. 722 void Relax(ShapeHandle s_old, ShapeHandle s_new, ShapeHandle* out); 723 724 // Used to implement MergeInputHandleShapesAndTypes and 725 // MergeOutputHandleShapesAndTypes. 726 bool MergeHandleShapesAndTypes( 727 const std::vector<ShapeAndType>& shapes_and_types, 728 std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT; 729 // Used to implement RelaxInputHandleShapesAndMergeTypes and 730 // RelaxOutputHandleShapesAndMergeTypes. 731 bool RelaxHandleShapesAndMergeTypes( 732 const std::vector<ShapeAndType>& shapes_and_types, 733 std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT; 734 735 // Forget all the previous merged shapes and dims. 736 void ForgetMerges() { 737 merged_shapes_.clear(); 738 merged_dims_.clear(); 739 } 740 741 // Helper method for MakeShapeFromTensor and MakeShapeFromShapeTensor. 742 Status InternalMakeShapeFromTensor( 743 bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t, 744 ShapeHandle tensor_shape, ShapeHandle* out); 745 746 ShapeManager shape_manager_; 747 748 // inputs_, outputs_, and input_tensors_as_shapes_ refer to values from 749 // `shape_manager_`. 750 std::vector<ShapeHandle> inputs_; 751 std::vector<const Tensor*> input_tensors_; 752 std::vector<bool> requested_input_tensor_; 753 std::vector<ShapeHandle> outputs_; 754 // Can have fewer elements than inputs_. 755 std::vector<ShapeHandle> input_tensors_as_shapes_; 756 std::vector<bool> requested_input_tensor_as_partial_shape_; 757 758 // input_handle_shapes_and_types_[i] is the list of shape/type pairs available 759 // through the resource handle passed along input i of the node. 760 // 761 // Values may be NULL. 762 std::vector<std::unique_ptr<std::vector<ShapeAndType>>> 763 input_handle_shapes_and_types_; 764 765 // output_handle_shapes_and_types_[i] is the list of shape/type pairs 766 // available through the resource handle passed along output i of the node. 767 // 768 // Values may be NULL. 769 std::vector<std::unique_ptr<std::vector<ShapeAndType>>> 770 output_handle_shapes_and_types_; 771 772 const int graph_def_version_; 773 const NodeDef* node_def_; 774 NameRangeMap input_name_map_; 775 NameRangeMap output_name_map_; 776 777 // An error set during construction. TODO(cwhipkey): remove when test 778 // constructor is removed. 779 Status construction_status_; 780 781 // Pair of shape or dim handles that are equivalent, ie that represent the 782 // same underlying shape of dimension. Note that for each pair at least one of 783 // the handles must contain an unknown shape, since we don't keep track of 784 // known shapes or dims here. 785 std::vector<std::pair<ShapeHandle, ShapeHandle>> merged_shapes_; 786 std::vector<std::pair<DimensionHandle, DimensionHandle>> merged_dims_; 787 788 TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext); 789 }; 790 791 // ----------------------------------------------------------------------------- 792 // Template and inline method implementations, please ignore 793 794 inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {} 795 inline Dimension::Dimension(int64 value) : value_(value) { 796 DCHECK(value >= 0 || value == InferenceContext::kUnknownDim) 797 << "Dimension must be non-negative or equal to " 798 "InferenceContext::kUnknownDim but got " 799 << value; 800 } 801 802 inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {} 803 inline Shape::Shape(const std::vector<DimensionHandle>& dims) 804 : rank_(dims.size()), dims_(dims) {} 805 806 inline DimensionOrConstant::DimensionOrConstant(DimensionHandle dim) 807 : dim(dim) { 808 DCHECK(dim.IsSet()) << "Internal error: Got nullptr for Dimension."; 809 } 810 811 inline DimensionOrConstant::DimensionOrConstant(int64 val) : val(val) { 812 DCHECK(val >= 0 || val == InferenceContext::kUnknownDim) 813 << "Dimension must be non-negative or equal to " 814 "InferenceContext::kUnknownDim but got " 815 << val; 816 } 817 818 template <class T> 819 Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const { 820 return GetNodeAttr(*node_def_, attr_name, value); 821 } 822 823 } // namespace shape_inference 824 } // namespace tensorflow 825 826 #endif // TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_ 827