/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_
#define TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_

#include <array>

#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {

// GetWindowedOutputSize(): Given an input tensor, kernel, stride and padding
// type, the function computes the output and padding dimensions.
//
// For example, ignoring batches or multiple features, a 1D convolution
// takes as input a 1D tensor of shape (H), and convolves it with a filter of
// shape (K).
//
// It also takes in a few additional parameters:
//
// Stride (S): the stride with which we apply the filters. This is the offset
// between locations where we apply the filters. A larger stride
// means that the output will be spatially smaller.
//
// Padding (P): the padding we apply to the input tensor along each
// dimension. This is usually used to make sure that the spatial dimensions
// do not shrink when we progress with convolutions. This function supports two
// types of padding:
//   SAME: the pad value is computed so that the output will have size H/S.
//   VALID: no padding is carried out.
// If you want to use EXPLICIT padding, GetWindowedOutputSizeVerbose must be
// called instead. Note the padded area is zero-filled.
//
// The output dimensions for convolution and many other operations, when given
// all the parameters above, are as follows:
// - When Padding = SAME: the output size is (H'), where
//     H' = ceil(float(H) / float(S))
//   where ceil is the ceiling function. The number of padded cells
//   is computed as:
//     Pc = ((H' - 1) * S + K - H) / 2
//   When the stride is 1, the expression simplifies to
//     H' = H, Pc = (K-1)/2.
//   This is where SAME comes from - the output has the same size as the input
//   has.
//
// - When Padding = VALID: the output size is computed as
//     H' = ceil(float(H - K + 1) / float(S))
//   and the number of padded cells is always zero.
//   When the stride is 1, the expression simplifies to
//     H' = H-K+1.
//
// For convolution, mathematically, the output value at location (r')
// is the inner product of two vectors: the chunk of input at
//     ((r'*S-Pr) : (r'*S-Pr+K)),
// and the filter.
//
// For 2D and 3D convolutions, the spatial dimensions are orthogonal, so the
// size and padding of each spatial dimension can be computed by calling
// GetWindowedOutputSize separately for each dimension.
Status GetWindowedOutputSize(int64 input_size, int64 filter_size, int64 stride,
                             Padding padding_type, int64* output_size,
                             int64* padding_size);

// The V2 version computes the same outputs with arbitrary dilation_rate.
// The output dimensions are computed as follows:
// - When adding dilation_rate (D), we compute an effective filter size (K'):
//     K' = (K - 1) * D + 1
// - When Padding = SAME: the output size is (H'), where
//     H' = ceil(float(H) / float(S))
//   where ceil is the ceiling function.
//   The number of padded cells
//   is computed as:
//     Pc = ((H' - 1) * S + K' - H) / 2
//   When the stride is 1, the expression simplifies to
//     H' = H, Pc = (K'-1)/2.
//   This is where SAME comes from - the output has the same size as the input
//   has.
//
// - When Padding = VALID: the output size is computed as
//     H' = ceil(float(H - K' + 1) / float(S))
//   and the number of padded cells is always zero.
//   When the stride is 1, the expression simplifies to
//     H' = H-K'+1.
//
// If you want to use EXPLICIT padding, GetWindowedOutputSizeVerboseV2 must be
// called instead.
//
// TODO(b/67112639): Merge V2 versions and the original versions eventually.
Status GetWindowedOutputSizeV2(int64 input_size, int64 filter_size,
                               int64 dilation_rate, int64 stride,
                               Padding padding_type, int64* output_size,
                               int64* padding_size);

// Returns the same output dimensions as in GetWindowedOutputSize, but returns
// verbose padding dimensions (before/after), and EXPLICIT padding is supported.
// When padding_type is EXPLICIT, *padding_before and *padding_after must
// already point to initialized integers with the padding amounts. Otherwise,
// *padding_before and *padding_after are set by this function, and any
// excess padding (caused by an odd padding size value) is added to the
// 'padding_after' dimension.
Status GetWindowedOutputSizeVerbose(int64 input_size, int64 filter_size,
                                    int64 stride, Padding padding_type,
                                    int64* output_size, int64* padding_before,
                                    int64* padding_after);

// The V2 version computes the same outputs with arbitrary dilation_rate. For
// detailed equations, refer to the comments for GetWindowedOutputSizeV2().
Status GetWindowedOutputSizeVerboseV2(int64 input_size, int64 filter_size,
                                      int64 dilation_rate, int64 stride,
                                      Padding padding_type, int64* output_size,
                                      int64* padding_before,
                                      int64* padding_after);

// Given an input tensor, kernel, stride and padding type, populates the 3D size
// of the output tensor and padding to be applied to the input tensor at the
// lower end of every dimension. Use for 3D convolutions, where the input data
// is padded with zeros, as well as for 3D avg/max pooling, where the input data
// is padded with invalid values that are not considered for pooling. EXPLICIT
// padding is not supported.
Status Get3dOutputSize(const std::array<int64, 3>& input,
                       const std::array<int64, 3>& window,
                       const std::array<int64, 3>& strides,
                       Padding padding_type, std::array<int64, 3>* output_ptr,
                       std::array<int64, 3>* padding_ptr);

// The V2 version computes the same outputs with arbitrary dilation_rate. For
// detailed equations, refer to the comments for GetWindowedOutputSizeV2().
Status Get3dOutputSizeV2(const std::array<int64, 3>& input,
                         const std::array<int64, 3>& window,
                         const std::array<int64, 3>& dilations,
                         const std::array<int64, 3>& strides,
                         Padding padding_type, std::array<int64, 3>* output_ptr,
                         std::array<int64, 3>* padding_ptr);

namespace shape_inference {

// Like GetWindowedOutputSize, but deals with DimensionHandles. Does not support
// EXPLICIT padding.
Status GetWindowedOutputSizeFromDims(InferenceContext* c,
                                     DimensionHandle input_size,
                                     DimensionOrConstant filter_size,
                                     int64 stride, Padding padding_type,
                                     DimensionHandle* output_size);

// The V2 version computes the same outputs with arbitrary dilation_rate, and
// supports EXPLICIT padding. For detailed equations, refer to the comments
// for GetWindowedOutputSizeV2().
The 'padding_before' and 'padding_after' 162 // parameters are only used if padding_type == EXPLICIT. 163 Status GetWindowedOutputSizeFromDimsV2( 164 InferenceContext* c, DimensionHandle input_size, 165 DimensionOrConstant filter_size, int64 dilation_rate, int64 stride, 166 Padding padding_type, int64 padding_before, int64 padding_after, 167 DimensionHandle* output_size); 168 169 // Transfers shape of input(0) to output(0). 170 Status UnchangedShape(shape_inference::InferenceContext* c); 171 172 // Transfers shape of input(0) to output(0), after asserting its rank is <rank>. 173 inline Status UnchangedShapeWithRank(shape_inference::InferenceContext* c, 174 int32 rank) { 175 ShapeHandle out; 176 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &out)); 177 c->set_output(0, out); 178 return Status::OK(); 179 } 180 181 // Transfers shape of input(0) to output(0), after asserting its rank >= <rank>. 182 inline Status UnchangedShapeWithRankAtLeast( 183 shape_inference::InferenceContext* c, int32 rank) { 184 ShapeHandle out; 185 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), rank, &out)); 186 c->set_output(0, out); 187 return Status::OK(); 188 } 189 190 // Transfers shape of input(0) to output(0), after asserting its rank <= <rank>. 191 inline Status UnchangedShapeWithRankAtMost(shape_inference::InferenceContext* c, 192 int32 rank) { 193 ShapeHandle out; 194 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), rank, &out)); 195 c->set_output(0, out); 196 return Status::OK(); 197 } 198 199 // Shape function for use with ops no outputs. 200 inline Status NoOutputs(shape_inference::InferenceContext* c) { 201 return Status::OK(); 202 } 203 204 // Shape function for ops that output a single scalar value. 205 inline Status ScalarShape(shape_inference::InferenceContext* c) { 206 c->set_output(0, c->Scalar()); 207 return Status::OK(); 208 } 209 210 // Shape function for binary ops where both inputs and the output match. 
211 inline Status MergeBothInputsShapeFn(InferenceContext* c) { 212 ShapeHandle out; 213 TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &out)); 214 c->set_output(0, out); 215 return Status::OK(); 216 } 217 218 // Returns a new shape with the specified dims arranged in the specified 219 // format. The returned value is owned by this context. 220 // Note: if format = "FORMAT_NCHW_VECT_C" then C represents the outer_depth. 221 Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N, 222 const std::vector<DimensionOrConstant>& spatial, 223 DimensionOrConstant C, ShapeHandle* out, 224 shape_inference::InferenceContext* context); 225 226 // Shape function for MatMul-like operations. 227 Status MatMulShape(shape_inference::InferenceContext* c); 228 229 // Shape function for BiasAdd-like operations. 230 Status BiasAddShape(shape_inference::InferenceContext* c); 231 232 // Shape function for BiasAddGrad-like operations. 233 Status BiasAddGradShape(shape_inference::InferenceContext* c); 234 235 // Shape function for Conv2D-like operations that support explicit padding. 236 Status Conv2DShapeWithExplicitPadding(shape_inference::InferenceContext* c); 237 238 // Shape function for Conv2D-like operations that do not support explicit 239 // padding. 240 Status Conv2DShape(shape_inference::InferenceContext* c); 241 242 // Shape function for Conv3D-like operations. 243 Status Conv3DShape(shape_inference::InferenceContext* c); 244 245 // Shape function for DepthwiseConv2D-like operations. 246 Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c); 247 248 // Shape function for AvgPool-like operations. 249 Status AvgPoolShape(shape_inference::InferenceContext* c); 250 251 // Shape function for FusedBatchNorm and FusedBatchNormV2 operations. 252 Status FusedBatchNormShape(shape_inference::InferenceContext* c); 253 254 // Shape function for FusedBatchNormGrad and FusedBatchNormGradV2 operations. 
255 Status FusedBatchNormGradShape(shape_inference::InferenceContext* c); 256 257 // Shape function for MaxPool-like operations. 258 Status MaxPoolShape(shape_inference::InferenceContext* c); 259 260 // Shape function for MaxPoolV2-like operations. 261 Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs); 262 263 // Shape function for 3D Pooling operations. 264 Status Pool3DShape(shape_inference::InferenceContext* c); 265 266 // Shape function for use with ops whose output shapes are unknown. 267 Status UnknownShape(shape_inference::InferenceContext* c); 268 269 // Shape function for reduction operations. 270 Status ReductionShape(shape_inference::InferenceContext* c); 271 272 // Shape function for concat operations. 273 // <num_inputs_to_concat> is the number of inputs to concatenate and are taken 274 // from inputs 275 // [1,num_inputs_to_concat] of the op. Input 0 is the concat_dim input. 276 Status ConcatShape(shape_inference::InferenceContext* c, 277 int num_inputs_to_concat); 278 279 // Shape function for concat operations. 280 Status ConcatV2Shape(shape_inference::InferenceContext* c); 281 282 Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat); 283 284 // Shape function for binary operators that broadcast their inputs 285 // and with output to output_index. 286 // Note: out cannot be NULL. 287 Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, 288 ShapeHandle shape_x, 289 ShapeHandle shape_y, 290 ShapeHandle* out); 291 292 // Shape function for binary operators that broadcast their inputs 293 // and with output to output_index. 294 inline Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c, 295 int output_index) { 296 ShapeHandle out; 297 TF_RETURN_IF_ERROR( 298 BroadcastBinaryOpOutputShapeFnHelper(c, c->input(0), c->input(1), &out)); 299 c->set_output(output_index, out); 300 return Status::OK(); 301 } 302 303 // Shape function for binary operators that broadcast their inputs. 
304 // Tested by ops/math_ops_test.cc. 305 inline Status BroadcastBinaryOpShapeFn(InferenceContext* c) { 306 return BroadcastBinaryOpOutputShapeFn(c, 0); 307 } 308 309 // Shape function for random operations. 310 Status RandomShape(shape_inference::InferenceContext* c); 311 312 // Shape function for Slice opertaions. 313 Status SliceShape(shape_inference::InferenceContext* c); 314 315 // Validates the 3 component tensors of a sparse tensor have the proper 316 // shapes. This mimics SparseTensor.__init__ in python/framework/ops.py. 317 Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, 318 ShapeHandle values_shape, ShapeHandle shape_shape); 319 320 // Shape function for ScatterNd update/add/sub/... operations. 321 Status ScatterNdUpdateShape(InferenceContext* c); 322 323 // Shape function for ops with an explicit "shape" attribute. 324 Status ExplicitShape(InferenceContext* c); 325 326 // Shape function for multiple-output ops with an explicit "shapes" attribute. 327 Status ExplicitShapes(InferenceContext* c); 328 329 // Shape function for SparseReduceMax and SparseReduceSum. 330 Status SparseReduceShapeFn(InferenceContext* c); 331 332 } // namespace shape_inference 333 334 } // namespace tensorflow 335 336 #endif // TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_ 337