1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_ 16 #define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_ 17 18 #include <vector> 19 20 #include "tensorflow/core/framework/node_def_util.h" 21 #include "tensorflow/core/framework/tensor.h" 22 #include "tensorflow/core/lib/core/errors.h" 23 #include "tensorflow/core/lib/core/status.h" 24 #include "tensorflow/core/lib/gtl/inlined_vector.h" 25 #include "tensorflow/core/platform/macros.h" 26 27 namespace tensorflow { 28 29 class ShapeRefiner; 30 class ShapeRefinerTest; 31 32 namespace grappler { 33 class GraphProperties; 34 class SymbolicShapeManager; 35 } // namespace grappler 36 37 namespace shape_inference { 38 39 struct DimensionOrConstant; 40 class InferenceContext; 41 42 // Dimension values are accessed through InferenceContext. 43 class Dimension { 44 private: 45 Dimension(); 46 Dimension(int64 value); 47 ~Dimension() {} 48 49 const int64 value_; 50 51 friend class InferenceContext; 52 friend class ShapeManager; 53 TF_DISALLOW_COPY_AND_ASSIGN(Dimension); 54 }; 55 56 class DimensionHandle { 57 public: 58 DimensionHandle() {} 59 bool SameHandle(DimensionHandle d) const { return ptr_ == d.ptr_; } 60 std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); } 61 62 private: 63 DimensionHandle(const Dimension* dim) { ptr_ = dim; } 64 65 const Dimension* operator->() const { return ptr_; } 66 bool IsSet() const { return ptr_ != nullptr; } 67 68 const Dimension* ptr_ = nullptr; 69 70 friend struct DimensionOrConstant; 71 friend class InferenceContext; 72 friend class ShapeInferenceTest; 73 friend class ShapeInferenceTestutil; 74 friend class ::tensorflow::ShapeRefinerTest; 75 friend class ShapeManager; 76 friend class ::tensorflow::grappler::GraphProperties; 77 friend class ::tensorflow::grappler::SymbolicShapeManager; 78 79 // Intentionally copyable. 80 }; 81 82 // Shape rank and dimensions are accessed through InferenceContext. 83 class Shape { 84 private: 85 Shape(); 86 Shape(const std::vector<DimensionHandle>& dims); 87 ~Shape() {} 88 89 const int32 rank_; 90 const std::vector<DimensionHandle> dims_; 91 92 friend class InferenceContext; 93 friend class ShapeManager; 94 friend class ::tensorflow::grappler::SymbolicShapeManager; 95 96 TF_DISALLOW_COPY_AND_ASSIGN(Shape); 97 }; 98 99 class ShapeHandle { 100 public: 101 ShapeHandle() {} 102 bool SameHandle(ShapeHandle s) const { return ptr_ == s.ptr_; } 103 std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); } 104 105 private: 106 ShapeHandle(const Shape* shape) { ptr_ = shape; } 107 const Shape* operator->() const { return ptr_; } 108 bool IsSet() const { return ptr_ != nullptr; } 109 110 const Shape* ptr_ = nullptr; 111 112 friend class InferenceContext; 113 friend class ShapeInferenceTest; 114 friend class ShapeInferenceTestutil; 115 friend class ::tensorflow::ShapeRefinerTest; 116 friend class ShapeManager; 117 friend class ::tensorflow::grappler::SymbolicShapeManager; 118 119 // Intentionally copyable. 120 }; 121 122 // Struct used to allow functions to take DimensionHandle or a dimension value. 123 // Not meant to be constructed directly. 124 struct DimensionOrConstant { 125 public: 126 // Intentionally not explicit. 127 DimensionOrConstant(DimensionHandle dim); 128 129 // val must be non-negative or InferenceContext::kUnknownDim. 130 DimensionOrConstant(int64 val); 131 132 // dim takes precedence. If dim != nullptr, val is ignored. 133 DimensionHandle dim; 134 int64 val; 135 136 private: 137 DimensionOrConstant(); 138 }; 139 140 struct ShapeAndType { 141 ShapeAndType() {} 142 ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {} 143 144 ShapeHandle shape; 145 DataType dtype = DT_INVALID; 146 }; 147 148 // Shape inference functions registered on ops in REGISTER_OP implement 149 // their shape functions in terms of this InferenceContext. An InferenceContext 150 // is created by the framework and passed to a shape inference function. The 151 // shape inference function calls functions on the context, and should call 152 // set_output() to set the shape on all outputs. 153 // 154 // To infer shapes for user-defined functions see ShapeRefiner. 155 // 156 // All Shape* and Dimension* returned by functions of InferenceContext are owned 157 // by the InferenceContext. 158 class InferenceContext { 159 public: 160 static constexpr int64 kUnknownDim = -1; 161 static constexpr int32 kUnknownRank = -1; 162 163 // <input_tensors> is NULL-padded to be the same size as <input_shapes>. 164 // 165 // Elements of <input_tensors_as_shapes> are used for when a shape function 166 // makes a call to MakeShapeFromShapeTensor; in particular, when the 167 // input_tensors[i] is nullptr but the shape represented by it is partially 168 // known from analysis of the graph. 169 // <input_tensors_as_shapes> can have fewer elements than <input_shapes>. 170 // Values of <input_tensors_as_shapes> do not need to outlive the context. 171 // 172 // REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext. 173 InferenceContext(int graph_def_version, const NodeDef* node_def, 174 const OpDef& op_def, 175 const std::vector<ShapeHandle>& input_shapes, 176 const std::vector<const Tensor*>& input_tensors, 177 const std::vector<ShapeHandle>& input_tensors_as_shapes, 178 std::vector<std::unique_ptr<std::vector<ShapeAndType>>> 179 input_handle_shapes_and_types); 180 181 // <input_tensors> is NULL-padded to be the same size as <input_shapes>. 182 // 183 // Elements of <input_tensors_as_shapes> are used for when a shape 184 // function makes a call to MakeShapeFromShapeTensor; in particular, when 185 // the input_tensors[i] is nullptr but the shape represented by it is 186 // partially known from analysis of the graph. <input_tensors_as_shapes> 187 // can have fewer elements than <input_shapes>. Values of 188 // <input_tensors_as_shapes> do not need to outlive the context. 189 // 190 // REQUIRES: <node_def> is not NULL, and must outlive the 191 // InferenceContext. 192 InferenceContext( 193 int graph_def_version, const NodeDef* node_def, const OpDef& op_def, 194 const std::vector<TensorShapeProto>& input_shapes, 195 const std::vector<const Tensor*>& input_tensors, 196 const std::vector<TensorShapeProto>& input_tensors_as_shapes, 197 const std::vector< 198 std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>& 199 input_handle_shapes_and_types); 200 201 // <input_tensors> is NULL-padded to be the same size as <input_shapes>. 202 // 203 // Elements of <input_tensors_as_shapes> are used for when a shape 204 // function makes a call to MakeShapeFromShapeTensor; in particular, when 205 // the input_tensors[i] is nullptr but the shape represented by it is 206 // partially known from analysis of the graph. <input_tensors_as_shapes> 207 // can have fewer elements than <input_shapes>. Values of 208 // <input_tensors_as_shapes> do not need to outlive the context. 209 // 210 // REQUIRES: <node_def> is not NULL, and must outlive the 211 // InferenceContext. 212 InferenceContext( 213 int graph_def_version, const NodeDef* node_def, const OpDef& op_def, 214 const std::vector<PartialTensorShape>& input_shapes, 215 const std::vector<const Tensor*>& input_tensors, 216 const std::vector<PartialTensorShape>& input_tensors_as_shapes, 217 const std::vector<std::unique_ptr< 218 std::vector<std::pair<PartialTensorShape, DataType>>>>& 219 input_handle_shapes_and_types); 220 221 ~InferenceContext(); 222 223 // Runs the shape inference function 'fn' with 'this' as the 224 // argument, returns the status of the inference. 225 // 226 // On error, additional context is provided in the error message. 227 Status Run( 228 const std::function<Status(shape_inference::InferenceContext* c)>& fn); 229 230 // Merge the stored shape of the input in position idx with <shape> according 231 // to the following rules: 232 // 233 // - If the ShapeHandles are the same or <shape> is unknown, there will be no 234 // change. Otherwise if the stored shape is unknown, the new shape will be 235 // <shape>. 236 // - If both shapes are known, then they must have the same rank. 237 // - For any one dimension, if the values for that dimension in both shapes 238 // are known, then the values must match. 239 // - If one shape has equal or more information than the other shape in every 240 // dimension, the new shape will become the shape with more information. 241 // - Example: merging [2,?] and [?,2] results in [2,2] 242 // - Example: [2,2] cannot be merged with [1,2] 243 // 244 // This requires idx to be in the [0, num_inputs) range. If the merge is 245 // successful, return true. Return false otherwise. 246 bool MergeInput(int idx, ShapeHandle shape) { 247 ShapeHandle new_shape; 248 if (!Merge(inputs_[idx], shape, &new_shape).ok()) return false; 249 inputs_[idx] = new_shape; 250 return true; 251 } 252 253 // Relax the stored shape of the input in position idx with <shape> according 254 // to the following rules: 255 // 256 // - If the ShapeHandles are the same then the stored shape will be returned. 257 // - If either of the ShapeHandles are unknown, then a new UnknownShape will 258 // be returned. A new shape must be returned because we cannot claim that 259 // the resulting shape is necessarily the same as either of the input 260 // shapes. 261 // - If the shapes both have known ranks but their ranks are different, a new 262 // UnknownShape will be returned. 263 // - For any one dimension, if the value for that dimension in either of the 264 // shapes is unknown, a new shape will be returned with a new UnknownDim in 265 // that dimension. 266 // - For any one dimension, if the values for that dimension in both shapes 267 // are known but do not match, a new shape will be returned with a new 268 // UnknownDim in that dimension. 269 // - If both shapes have the same known rank and match in every dimension, 270 // the stored shape will be returned. 271 // - Example: relaxing [2,?] and [?,2] results in [?,?] 272 // - Example: relaxing [2,2] and [3,2] results in [?,2] 273 // - Example: relaxing [2,2] with [1,2,3] results in ? 274 // 275 // This requires idx to be in the [0, num_inputs) range. If the relax is 276 // successful and the new shape differs from the old one, store the new 277 // shape and return true. Return false otherwise. 278 bool RelaxInput(int idx, ShapeHandle shape) { 279 ShapeHandle new_shape; 280 Relax(inputs_[idx], shape, &new_shape); 281 if (inputs_[idx].SameHandle(new_shape)) { 282 return false; 283 } 284 inputs_[idx] = new_shape; 285 return true; 286 } 287 288 ShapeHandle input(int64 idx) const { return inputs_[idx]; } 289 Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const; 290 int num_inputs() const { return inputs_.size(); } 291 292 // Returns the input tensor at index <idx>, or nullptr if the input tensor is 293 // not available at the time of shape inference. 294 const Tensor* input_tensor(int idx) { 295 // Mark that this idx was requested. 296 requested_input_tensor_[idx] = true; 297 return input_tensors_[idx]; 298 } 299 300 // Returns true iff input_tensor(idx) was called by the shape function. 301 bool requested_input_tensor(int idx) const { 302 return requested_input_tensor_[idx]; 303 } 304 305 // Returns true if MakeShapeFromInputTensor was called but the constant 306 // input_tensor was not present. 307 bool requested_input_tensor_as_partial_shape(int idx) const { 308 return requested_input_tensor_as_partial_shape_[idx]; 309 } 310 311 void set_input_tensors(const std::vector<const Tensor*>& input_tensors) { 312 input_tensors_ = input_tensors; 313 } 314 315 void set_input_tensors_as_shapes( 316 const std::vector<ShapeHandle>& input_tensors_as_shapes) { 317 input_tensors_as_shapes_ = input_tensors_as_shapes; 318 } 319 320 void set_output(int idx, ShapeHandle shape) { outputs_[idx] = shape; } 321 Status set_output(StringPiece output_name, 322 const std::vector<ShapeHandle>& shapes); 323 324 int num_outputs() const { return outputs_.size(); } 325 ShapeHandle output(int idx) const { return outputs_[idx]; } 326 Status output(StringPiece output_name, 327 std::vector<ShapeHandle>* output) const; 328 329 AttrSlice attrs() const { return AttrSlice(*node_def_); } 330 331 string op() const; 332 333 // idx can be negative for an offset from end of dimensions. 334 // idx must be in the range [-1 * s.rank, s.rank). 335 DimensionHandle Dim(ShapeHandle s, int64 idx) { 336 if (s->rank_ == kUnknownRank) { 337 return UnknownDim(); 338 } 339 return DimKnownRank(s, idx); 340 } 341 // As above, but asserts that the rank of the shape is known. 342 static DimensionHandle DimKnownRank(ShapeHandle s, int64 idx) { 343 CHECK_NE(s->rank_, kUnknownRank); 344 if (idx < 0) { 345 return s->dims_[s->dims_.size() + idx]; 346 } 347 return s->dims_[idx]; 348 } 349 350 static int32 Rank(ShapeHandle s) { 351 DCHECK(s.IsSet()); 352 return s.IsSet() ? s->rank_ : kUnknownRank; 353 } 354 static bool RankKnown(ShapeHandle s) { 355 return (s.IsSet() && (Rank(s) != kUnknownRank)); 356 } 357 static inline int64 Value(DimensionOrConstant d) { 358 return d.dim.IsSet() ? d.dim->value_ : d.val; 359 } 360 static inline bool ValueKnown(DimensionOrConstant d) { 361 return Value(d) != kUnknownDim; 362 } 363 364 // Fills the output proto with the shape defined by the handle. 365 // "proto" is expected to be empty prior to the call. 366 void ShapeHandleToProto(ShapeHandle handle, TensorShapeProto* proto); 367 368 // Returns true if the rank and all dimensions of the Shape are known. 369 bool FullyDefined(ShapeHandle s); 370 371 // Returns the total number of elements, or an unknown dimension for an 372 // incomplete shape. 373 DimensionHandle NumElements(ShapeHandle s); 374 375 string DebugString(ShapeHandle s); 376 string DebugString(DimensionHandle d); 377 378 // Describes the whole context, for debugging purposes. 379 string DebugString() const; 380 381 // If <shape> has rank <rank>, or its rank is unknown, return OK and return 382 // the shape with asserted rank in <*out>. Otherwise return an error. 383 // 384 // Note that <*out> may be set to <shape>. 385 Status WithRank(ShapeHandle shape, int64 rank, 386 ShapeHandle* out) TF_MUST_USE_RESULT; 387 Status WithRankAtLeast(ShapeHandle shape, int64 rank, 388 ShapeHandle* out) TF_MUST_USE_RESULT; 389 Status WithRankAtMost(ShapeHandle shape, int64 rank, 390 ShapeHandle* out) TF_MUST_USE_RESULT; 391 392 // If <dim> has value <value>, or its value is unknown, returns OK and returns 393 // the dimension with asserted value in <*out>. Otherwise returns an error. 394 // 395 // Note that <*out> may be set to <dim>. 396 Status WithValue(DimensionHandle dim, int64 value, 397 DimensionHandle* out) TF_MUST_USE_RESULT; 398 399 // Merges <s0> and <s1> and returns the merged shape in <*out>. See 400 // 'MergeInput' function for full details and examples. 401 Status Merge(ShapeHandle s0, ShapeHandle s1, 402 ShapeHandle* out) TF_MUST_USE_RESULT; 403 404 // Asserts that <s>'s rank >= <prefix>'s rank, and the first 405 // <prefix.rank> dimensions of <s> are compatible with the dimensions of 406 // <prefix>. 407 // Returns the merged results in <*s_out> and <*prefix_out>. 408 Status MergePrefix(ShapeHandle s, ShapeHandle prefix, ShapeHandle* s_out, 409 ShapeHandle* prefix_out) TF_MUST_USE_RESULT; 410 411 // Merges <d0> and <d1> and returns the merged dimension in <*out>. If <d0> 412 // and <d1> have incompatible values, returns an error. 413 // 414 // Note that <*out> may be set to <d0> or <d1>. 415 Status Merge(DimensionHandle d0, DimensionHandle d1, 416 DimensionHandle* out) TF_MUST_USE_RESULT; 417 418 // Returns in <*out> a sub-shape of <s> with dimensions [start:]. 419 // <start> can be negative to index from the end of the shape. If <start> > 420 // rank of <s>, then an empty subshape is returned. 421 Status Subshape(ShapeHandle s, int64 start, 422 ShapeHandle* out) TF_MUST_USE_RESULT; 423 424 // Returns in <*out> a sub-shape of <s>, with dimensions [start:end]. 425 // <start> and <end> can be negative, to index from the end of the shape. 426 // <start> and <end> are set to the rank of <s> if > rank of <s>. 427 Status Subshape(ShapeHandle s, int64 start, int64 end, 428 ShapeHandle* out) TF_MUST_USE_RESULT; 429 430 // Returns in <*out> the result of appending the dimensions of <s2> to those 431 // of <s1>. 432 Status Concatenate(ShapeHandle s1, ShapeHandle s2, 433 ShapeHandle* out) TF_MUST_USE_RESULT; 434 435 // Returns in <out> the shape from replacing <s.dim[dim_index]> with 436 // <new_dim>. 437 Status ReplaceDim(ShapeHandle s, int64 dim_index, DimensionHandle new_dim, 438 ShapeHandle* out) TF_MUST_USE_RESULT; 439 440 // Returns a new shape with the given dims. The returned value is owned by 441 // this context. 442 ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims); 443 ShapeHandle MakeShape(std::initializer_list<DimensionOrConstant> dims); 444 445 // Returns a new unknown shape. 446 ShapeHandle UnknownShape(); 447 448 // Returns a shape with specified rank but unknown dims. 449 ShapeHandle UnknownShapeOfRank(int64 rank); 450 451 // Returns a new shape of zero dimensions. 452 ShapeHandle Scalar(); 453 454 // Returns a new shape of one dimension. 455 ShapeHandle Vector(DimensionOrConstant dim); 456 457 // Returns a new shape of two dimensions. 458 ShapeHandle Matrix(DimensionOrConstant dim1, DimensionOrConstant dim2); 459 460 // Returns in <out> a new shape whose dimension sizes come from input tensor 461 // <input_idx>. The tensor must be a 1-dimensional int32 or int64 tensor. If 462 // the input tensor is NULL, then an unknown shape is returned. 463 Status MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out); 464 465 // Returns in <out> a new shape corresponding to <proto>. 466 Status MakeShapeFromShapeProto(const TensorShapeProto& proto, 467 ShapeHandle* out); 468 469 // Returns in <out> a new shape corresponding to <partial_shape>. 470 Status MakeShapeFromPartialTensorShape( 471 const PartialTensorShape& partial_shape, ShapeHandle* out); 472 473 // Returns in <out> a new shape corresponding to <shape>. 474 Status MakeShapeFromTensorShape(const TensorShape& shape, ShapeHandle* out); 475 476 // Returns a new dimension of the given size. The returned value is owned by 477 // this context. 478 inline DimensionHandle MakeDim(DimensionOrConstant d) { 479 return shape_manager_.MakeDim(d); 480 } 481 482 inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); } 483 484 // Returns in <val> a scalar value from an input tensor <t>. The input tensor 485 // must be a 1-dimensional int32 or int64 tensor. Caller must ensure that the 486 // input tensor is not NULL. 487 Status GetScalarFromTensor(const Tensor* t, int64* val); 488 489 // Returns a new dimension whose value is given by a scalar input tensor. 490 // The input tensor must be in host memory, since it is dereferenced to get 491 // the value. 492 Status MakeDimForScalarInput(int idx, DimensionHandle* out); 493 494 // Returns a new dimension whose value is given by a scalar input tensor. 495 // This allows for a negative input dimension given the rank of a separate 496 // tensor. This rank can be negative if unknown. 497 // The input tensor must be in host memory, since it is dereferenced to get 498 // the value. 499 Status MakeDimForScalarInputWithNegativeIndexing(int idx, int input_rank, 500 DimensionHandle* out); 501 502 // Look up the attr for the NodeDef being evaluated with name attr_name and 503 // set *value to its value. If no attr with attr_name is found in def(), or 504 // the attr does not have a matching type, a non-ok status will be returned. 505 template <class T> 506 Status GetAttr(StringPiece attr_name, T* value) const; 507 508 // Returns in <out> the result of dividing <dividend> by <divisor>. 509 // Returns an error if <divisor> is not positive or if <evenly_divisible> 510 // and <divisor> does not evenly divide <dividend>. 511 Status Divide(DimensionHandle dividend, DimensionOrConstant divisor, 512 bool evenly_divisible, DimensionHandle* out); 513 514 // Returns in <out> the sum of <first> and <second>. 515 Status Add(DimensionHandle first, DimensionOrConstant second, 516 DimensionHandle* out); 517 518 // Returns in <out> the dimension that is <first> minus <second>. 519 Status Subtract(DimensionHandle first, DimensionOrConstant second, 520 DimensionHandle* out); 521 522 // Returns in <out> the product of <first> and <second>. 523 Status Multiply(DimensionHandle first, DimensionOrConstant second, 524 DimensionHandle* out); 525 526 // Returns in <out> the minimum of <first> and <second>. If either <first> or 527 // <second> is zero the results is zero. Otherwise, if either <first> or 528 // <second> is unknown the results is unknown. 529 Status Min(DimensionHandle first, DimensionOrConstant second, 530 DimensionHandle* out); 531 532 // Returns in <out> the maximum of <first> and <second>. If either <first> or 533 // <second> is unknown the results is unknown. 534 Status Max(DimensionHandle first, DimensionOrConstant second, 535 DimensionHandle* out); 536 537 Status construction_status() const { return construction_status_; } 538 539 // Methods to propagate shape and dtype on edges of handles. Handles are the 540 // dtype DT_RESOURCE which can be used to access state stored in a 541 // ResourceManager. When ops (such as variables) consume these handles to 542 // produce tensors they might need to know side-information about the shapes 543 // and dtypes of tensors which can be accessed via the handle. These methods 544 // propagate that information. Output handle dtypes and shapes are ignored if 545 // the output tensor is not of type DT_RESOURCE. 546 547 // Merge the stored shapes and types corresponding to the input handle in 548 // position idx with the specified shapes and types. This requires idx to be 549 // in the [0, num_inputs) range. 550 // 551 // If the merge is successful and any of the new shapes differs from the old 552 // one, or any of the old dtypes was DT_INVALID, store the new shapes and 553 // return true. Return false otherwise. 554 // 555 // See 'MergeInput' function for full details and examples. 556 bool MergeInputHandleShapesAndTypes( 557 int idx, 558 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT; 559 560 // As MergeInputHandleShapesAndTypes, but for an output. 561 bool MergeOutputHandleShapesAndTypes( 562 int idx, 563 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT; 564 565 // Relaxes the stored shapes and types corresponding to the input handle in 566 // position idx with the specified shapes and types. This requires idx to be 567 // in the [0, num_inputs) range. 568 // 569 // If the relax is successful and any of the new shapes differs from the old 570 // one, or any of the old dtypes was DT_INVALID, store the new shapes and 571 // return true. Return false otherwise. 572 // 573 // See 'RelaxInput' function for full details and examples. 574 bool RelaxInputHandleShapesAndMergeTypes( 575 int idx, 576 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT; 577 578 // As RelaxInputHandleShapesAndTypes, but for an output. 579 bool RelaxOutputHandleShapesAndMergeTypes( 580 int idx, 581 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT; 582 583 // Returns the output handle shapes and types, for the resource tensor output 584 // at index <idx>. Returns NULL if the shape and types were never set. 585 const std::vector<ShapeAndType>* output_handle_shapes_and_types(int idx) { 586 return output_handle_shapes_and_types_[idx].get(); 587 } 588 589 // Returns the inputs handle shapes and types, for the resource tensor output 590 // at index <idx>. Returns NULL if the shape and types were not available. 591 const std::vector<ShapeAndType>* input_handle_shapes_and_types(int idx) { 592 return input_handle_shapes_and_types_[idx].get(); 593 } 594 595 void set_output_handle_shapes_and_types( 596 int idx, const std::vector<ShapeAndType>& shapes_and_types) { 597 output_handle_shapes_and_types_[idx].reset( 598 new std::vector<ShapeAndType>(shapes_and_types)); 599 } 600 601 // Note that shape functions should usually call MakeShapeFromShapeTensor, 602 // as it does more analysis to provide partial shapes. 603 // 604 // Returns in <out> a new shape whose dimension sizes come from tensor <t>. 605 // The tensor must be a 1-dimensional int32 or int64 tensor. If <t> is NULL, 606 // then an unknown shape is returned. 607 Status MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape, 608 ShapeHandle* out); 609 610 int graph_def_version() const { return graph_def_version_; } 611 612 const std::vector<std::pair<ShapeHandle, ShapeHandle>>& MergedShapes() const { 613 return merged_shapes_; 614 } 615 const std::vector<std::pair<DimensionHandle, DimensionHandle>>& MergedDims() 616 const { 617 return merged_dims_; 618 } 619 620 private: 621 // Creates and stores shapes for use in InferenceContext. 622 class ShapeManager { 623 public: 624 ShapeManager(); 625 ~ShapeManager(); 626 627 // Returns a new shape with the given dims. The returned value is owned by 628 // this class. 629 ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims); 630 631 // Returns a new unknown shape. 632 ShapeHandle UnknownShape(); 633 634 // Returns a new dimension of the given size. The returned value 635 // is owned by this class. 636 inline DimensionHandle MakeDim(DimensionOrConstant d) { 637 if (d.dim.IsSet()) { 638 return d.dim; 639 } else { 640 all_dims_.push_back(new Dimension(d.val)); 641 return all_dims_.back(); 642 } 643 } 644 645 private: 646 std::vector<Shape*> all_shapes_; // values are owned. 647 std::vector<Dimension*> all_dims_; // values are owned. 648 }; 649 650 friend class ::tensorflow::grappler::GraphProperties; 651 652 // Friend for user-defined function shape inference purposes. 653 friend class ::tensorflow::ShapeRefiner; 654 655 friend class ShapeInferenceTest; // For testing Relax functions. 656 friend class ShapeInferenceTestutil; // For testing shapes. 657 658 // Shared initialization across the two constructors. Remove 659 // once we get rid of one of them. 660 void PreInputInit(const OpDef& op_def, 661 const std::vector<const Tensor*>& input_tensors, 662 const std::vector<ShapeHandle>& input_tensors_as_shapes); 663 void PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>> 664 input_handle_data); 665 666 DimensionHandle GetDimension(const DimensionOrConstant& d); 667 668 Status ReturnUnknownShape(ShapeHandle* out) { 669 *out = UnknownShape(); 670 return Status::OK(); 671 } 672 Status ReturnCreatedShape(const std::vector<DimensionHandle>& dims, 673 ShapeHandle* out) { 674 *out = MakeShape(dims); 675 return Status::OK(); 676 } 677 678 // Adds additional context to the given status. 679 Status AttachContext(const Status& status); 680 681 // Relaxes an existing value <d_old> with a new value <d_new> and returns the 682 // relaxed dimension in <*out>. If <d_old> and <d_new> have incompatible 683 // values, returns an error. 684 // 685 // Note that <*out> may be set to <d_old> or <d_new>. 686 void Relax(DimensionHandle d_old, DimensionHandle d_new, 687 DimensionHandle* out); 688 // Relaxes an existing shape <s_old> with a new shape <s_new> and returns the 689 // relaxed shape in <*out>. See 'RelaxInput' function for full details and 690 // examples. 691 void Relax(ShapeHandle s_old, ShapeHandle s_new, ShapeHandle* out); 692 693 // Used to implement MergeInputHandleShapesAndTypes and 694 // MergeOutputHandleShapesAndTypes. 695 bool MergeHandleShapesAndTypes( 696 const std::vector<ShapeAndType>& shapes_and_types, 697 std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT; 698 // Used to implement RelaxInputHandleShapesAndMergeTypes and 699 // RelaxOutputHandleShapesAndMergeTypes. 700 bool RelaxHandleShapesAndMergeTypes( 701 const std::vector<ShapeAndType>& shapes_and_types, 702 std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT; 703 704 // Forget all the previous merged shapes and dims. 705 void ForgetMerges() { 706 merged_shapes_.clear(); 707 merged_dims_.clear(); 708 } 709 710 ShapeManager shape_manager_; 711 712 // inputs_, outputs_, and input_tensors_as_shapes_ refer to values from 713 // `shape_manager_`. 714 std::vector<ShapeHandle> inputs_; 715 std::vector<const Tensor*> input_tensors_; 716 std::vector<bool> requested_input_tensor_; 717 std::vector<ShapeHandle> outputs_; 718 // Can have fewer elements than inputs_. 719 std::vector<ShapeHandle> input_tensors_as_shapes_; 720 std::vector<bool> requested_input_tensor_as_partial_shape_; 721 722 // input_handle_shapes_and_types_[i] is the list of shape/type pairs available 723 // through the resource handle passed along input i of the node. 724 // 725 // Values may be NULL. 726 std::vector<std::unique_ptr<std::vector<ShapeAndType>>> 727 input_handle_shapes_and_types_; 728 729 // output_handle_shapes_and_types_[i] is the list of shape/type pairs 730 // available through the resource handle passed along output i of the node. 731 // 732 // Values may be NULL. 733 std::vector<std::unique_ptr<std::vector<ShapeAndType>>> 734 output_handle_shapes_and_types_; 735 736 const int graph_def_version_; 737 const NodeDef* node_def_; 738 NameRangeMap input_name_map_; 739 NameRangeMap output_name_map_; 740 741 // An error set during construction. TODO(cwhipkey): remove when test 742 // constructor is removed. 743 Status construction_status_; 744 745 // Pair of shape or dim handles that are equivalent, ie that represent the 746 // same underlying shape of dimension. Note that for each pair at least one of 747 // the handles must contain an unknown shape, since we don't keep track of 748 // known shapes or dims here. 749 std::vector<std::pair<ShapeHandle, ShapeHandle>> merged_shapes_; 750 std::vector<std::pair<DimensionHandle, DimensionHandle>> merged_dims_; 751 752 TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext); 753 }; 754 755 // ----------------------------------------------------------------------------- 756 // Template and inline method implementations, please ignore 757 758 inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {} 759 inline Dimension::Dimension(int64 value) : value_(value) { 760 DCHECK(value >= 0 || value == InferenceContext::kUnknownDim) 761 << "Dimension must be non-negative or equal to " 762 "InferenceContext::kUnknownDim but got " 763 << value; 764 } 765 766 inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {} 767 inline Shape::Shape(const std::vector<DimensionHandle>& dims) 768 : rank_(dims.size()), dims_(dims) {} 769 770 inline DimensionOrConstant::DimensionOrConstant(DimensionHandle dim) 771 : dim(dim) { 772 DCHECK(dim.IsSet()) << "Internal error: Got nullptr for Dimension."; 773 } 774 775 inline DimensionOrConstant::DimensionOrConstant(int64 val) : val(val) { 776 DCHECK(val >= 0 || val == InferenceContext::kUnknownDim) 777 << "Dimension must be non-negative or equal to " 778 "InferenceContext::kUnknownDim but got " 779 << val; 780 } 781 782 template <class T> 783 Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const { 784 return GetNodeAttr(*node_def_, attr_name, value); 785 } 786 787 } // namespace shape_inference 788 } // namespace tensorflow 789 790 #endif // TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_ 791