/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_
#define TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_

#include <array>

#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {

// GetWindowedOutputSize(): Given an input tensor, kernel, stride and padding
// type, the function computes the output and padding dimensions.
//
// For example, ignoring batches or multiple features, a 1D convolution
// takes as input a 1D tensor of shape (H), and convolves it with a filter of
// shape (K).
//
// It also takes in a few additional parameters:
//
// Stride (S): the stride with which we apply the filters. This is the offset
// between locations where we apply the filters. A larger stride
// means that the output will be spatially smaller.
//
// Padding (P): the padding we apply to the input tensor along each
// dimension. This is usually used to make sure that the spatial dimensions
// do not shrink when we progress with convolutions. Two types of padding are
// often used:
//   SAME: the pad value is computed so that the output will have size H/S.
//   VALID: no padding is carried out.
//   The padded area is zero-filled.
//
// The output dimensions for convolution and many other operations, when given
// all the parameters above, are as follows:
// - When Padding = SAME: the output size is (H'), where
//     H' = ceil(float(H) / float(S))
//   where ceil is the ceiling function. The number of padded cells
//   is computed as:
//     Pc = ((H' - 1) * S + K - H) / 2
//   When the stride is 1, the expression simplifies to
//     H' = H, Pc = (K-1)/2.
//   This is where SAME comes from - the output has the same size as the
//   input.
//
// - When Padding = VALID: the output size is computed as
//     H' = ceil(float(H - K + 1) / float(S))
//   and the number of padded cells is always zero.
//   When the stride is 1, the expression simplifies to
//     H' = H-K+1.
//
// For convolution, mathematically, the output value at location (r')
// is the inner product of two vectors: the chunk of input at
//    ((r'*S-Pr) : (r'*S-Pr+K)),
// and the filter.
//
// For 2D and 3D convolutions, the spatial dimensions are orthogonal, so the
// size and padding of each spatial dimension can be computed by calling
// GetWindowedOutputSize separately for each dimension.
//
Status GetWindowedOutputSize(int64 input_size, int64 filter_size, int64 stride,
                             Padding padding_type, int64* output_size,
                             int64* padding_size);

// The V2 version computes the same outputs with arbitrary dilation_rate.
// The output dimensions are computed as follows:
// - When adding dilation_rate (D), we compute an effective filter size (K'):
//     K' = (K - 1) * D + 1
// - When Padding = SAME: the output size is (H'), where
//     H' = ceil(float(H) / float(S))
//   where ceil is the ceiling function. The number of padded cells
//   is computed as:
//     Pc = ((H' - 1) * S + K' - H) / 2
//   When the stride is 1, the expression simplifies to
//     H' = H, Pc = (K'-1)/2.
//   This is where SAME comes from - the output has the same size as the
//   input.
//
// - When Padding = VALID: the output size is computed as
//     H' = ceil(float(H - K' + 1) / float(S))
//   and the number of padded cells is always zero.
//   When the stride is 1, the expression simplifies to
//     H' = H-K'+1.
//
// TODO(b/67112639): Merge V2 versions and the original versions eventually.
Status GetWindowedOutputSizeV2(int64 input_size, int64 filter_size,
                               int64 dilation_rate, int64 stride,
                               Padding padding_type, int64* output_size,
                               int64* padding_size);

// Returns the same output dimensions as in GetWindowedOutputSize, but returns
// verbose padding dimensions (before/after). Any excess padding
// (caused by an odd padding size value) is added to the 'padding_after'
// dimension.
Status GetWindowedOutputSizeVerbose(int64 input_size, int64 filter_size,
                                    int64 stride, Padding padding_type,
                                    int64* output_size, int64* padding_before,
                                    int64* padding_after);

// The V2 version computes the same outputs with arbitrary dilation_rate. For
// detailed equations, refer to the comments for GetWindowedOutputSizeV2().
Status GetWindowedOutputSizeVerboseV2(int64 input_size, int64 filter_size,
                                      int64 dilation_rate, int64 stride,
                                      Padding padding_type, int64* output_size,
                                      int64* padding_before,
                                      int64* padding_after);

// Given an input tensor, kernel, stride and padding type, populates the 3D
// size of the output tensor and padding to be applied to the input tensor at
// the lower end of every dimension. Use for 3D convolutions, where the input
// data is padded with zeros, as well as for 3D avg/max pooling, where the
// input data is padded with invalid values that are not considered for
// pooling.
Status Get3dOutputSize(const std::array<int64, 3>& input,
                       const std::array<int64, 3>& window,
                       const std::array<int64, 3>& strides,
                       Padding padding_type, std::array<int64, 3>* output_ptr,
                       std::array<int64, 3>* padding_ptr);

// The V2 version computes the same outputs with arbitrary dilation_rate. For
// detailed equations, refer to the comments for GetWindowedOutputSizeV2().
Status Get3dOutputSizeV2(const std::array<int64, 3>& input,
                         const std::array<int64, 3>& window,
                         const std::array<int64, 3>& dilations,
                         const std::array<int64, 3>& strides,
                         Padding padding_type, std::array<int64, 3>* output_ptr,
                         std::array<int64, 3>* padding_ptr);

namespace shape_inference {

// Like GetWindowedOutputSize, but deals with DimensionHandles. Does not
// support negative values in filter_size.
Status GetWindowedOutputSizeFromDims(InferenceContext* c,
                                     DimensionHandle input_size,
                                     DimensionOrConstant filter_size,
                                     int64 stride, Padding padding_type,
                                     DimensionHandle* output_size);

// The V2 version computes the same outputs with arbitrary dilation_rate. For
// detailed equations, refer to the comments for GetWindowedOutputSizeV2().
Status GetWindowedOutputSizeFromDimsV2(InferenceContext* c,
                                       DimensionHandle input_size,
                                       DimensionOrConstant filter_size,
                                       int64 dilation_rate, int64 stride,
                                       Padding padding_type,
                                       DimensionHandle* output_size);

// Transfers shape of input(0) to output(0).
Status UnchangedShape(shape_inference::InferenceContext* c);

// Transfers shape of input(0) to output(0), after asserting its rank is
// <rank>. Any error from the rank check is propagated to the caller and
// output(0) is left unset.
inline Status UnchangedShapeWithRank(shape_inference::InferenceContext* c,
                                     int32 rank) {
  ShapeHandle out;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &out));
  c->set_output(0, out);
  return Status::OK();
}

// Transfers shape of input(0) to output(0), after asserting its rank >= <rank>.
inline Status UnchangedShapeWithRankAtLeast(
    shape_inference::InferenceContext* c, int32 rank) {
  ShapeHandle out;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), rank, &out));
  c->set_output(0, out);
  return Status::OK();
}

// Transfers shape of input(0) to output(0), after asserting its rank <= <rank>.
inline Status UnchangedShapeWithRankAtMost(shape_inference::InferenceContext* c,
                                           int32 rank) {
  ShapeHandle out;
  TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), rank, &out));
  c->set_output(0, out);
  return Status::OK();
}

// Shape function for use with ops with no outputs.
inline Status NoOutputs(shape_inference::InferenceContext* c) {
  return Status::OK();
}

// Shape function for ops that output a single scalar value.
inline Status ScalarShape(shape_inference::InferenceContext* c) {
  c->set_output(0, c->Scalar());
  return Status::OK();
}

// Shape function for binary ops where both inputs and the output match.
// Fails if input(0) and input(1) cannot be merged into a single shape.
inline Status MergeBothInputsShapeFn(InferenceContext* c) {
  ShapeHandle out;
  TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &out));
  c->set_output(0, out);
  return Status::OK();
}

// Returns a new shape with the specified dims arranged in the specified
// format. The returned value is owned by this context.
// Note: if format = "FORMAT_NCHW_VECT_C" then C represents the outer_depth.
Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
                           const std::vector<DimensionOrConstant>& spatial,
                           DimensionOrConstant C, ShapeHandle* out,
                           shape_inference::InferenceContext* context);

// Shape function for MatMul-like operations.
Status MatMulShape(shape_inference::InferenceContext* c);

// Shape function for BiasAdd-like operations.
Status BiasAddShape(shape_inference::InferenceContext* c);

// Shape function for BiasAddGrad-like operations.
Status BiasAddGradShape(shape_inference::InferenceContext* c);

// Shape function for Conv2D-like operations.
Status Conv2DShape(shape_inference::InferenceContext* c);

// Shape function for Conv3D-like operations.
Status Conv3DShape(shape_inference::InferenceContext* c);

// Shape function for DepthwiseConv2D-like operations.
Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c);

// Shape function for AvgPool-like operations.
Status AvgPoolShape(shape_inference::InferenceContext* c);

// Shape function for FusedBatchNorm and FusedBatchNormV2 operations.
Status FusedBatchNormShape(shape_inference::InferenceContext* c);

// Shape function for FusedBatchNormGrad and FusedBatchNormGradV2 operations.
Status FusedBatchNormGradShape(shape_inference::InferenceContext* c);

// Shape function for MaxPool-like operations.
Status MaxPoolShape(shape_inference::InferenceContext* c);

// Shape function for MaxPoolV2-like operations.
Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs);

// Shape function for 3D Pooling operations.
Status Pool3DShape(shape_inference::InferenceContext* c);

// Shape function for use with ops whose output shapes are unknown.
Status UnknownShape(shape_inference::InferenceContext* c);

// Shape function for reduction operations.
Status ReductionShape(shape_inference::InferenceContext* c);

// Shape function for concat operations.
// <num_inputs_to_concat> is the number of inputs to concatenate; they are
// taken from inputs [1, num_inputs_to_concat] of the op. Input 0 is the
// concat_dim input.
Status ConcatShape(shape_inference::InferenceContext* c,
                   int num_inputs_to_concat);

// Shape function for concat operations.
Status ConcatV2Shape(shape_inference::InferenceContext* c);

// Shape function for binary operators that broadcast their inputs.
// Tested by ops/math_ops_test.cc.
Status BroadcastBinaryOpShapeFn(InferenceContext* c);

// Shape function for random operations.
Status RandomShape(shape_inference::InferenceContext* c);

// Validates the 3 component tensors of a sparse tensor have the proper
// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
                            ShapeHandle values_shape, ShapeHandle shape_shape);

// Shape function for ScatterNd update/add/sub/... operations.
Status ScatterNdUpdateShape(InferenceContext* c);

// Shape function for ops with an explicit "shape" attribute.
Status ExplicitShape(InferenceContext* c);

}  // namespace shape_inference

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_