/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_
#define TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_

#include <array>

#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {

// GetWindowedOutputSize(): Given an input tensor, kernel, stride and padding
// type, the function computes the output and padding dimensions.
//
// For example, ignoring batches or multiple features, a 1D convolution
// takes as input a 1D tensor of shape (H), and convolves it with a filter of
// shape (K).
//
// It also takes in a few additional parameters:
//
// Stride (S): the stride with which we apply the filters. This is the offset
// between locations where we apply the filters. A larger stride
// means that the output will be spatially smaller.
//
// Padding (P): the padding we apply to the input tensor along each
// dimension. This is usually used to make sure that the spatial dimensions
// do not shrink when we progress with convolutions. This function supports two
// types of padding.
//   SAME: the pad value is computed so that the output will have size
//         ceil(H / S).
//   VALID: no padding is carried out.
// If you want to use EXPLICIT padding, GetWindowedOutputSizeVerbose must be
// called instead. Note the padded area is zero-filled.
//
// The output dimensions for convolution and many other operations, when given
// all the parameters above, are as follows:
// - When Padding = SAME: the output size is (H'), where
//     H' = ceil(float(H) / float(S))
//   where ceil is the ceiling function. The number of padded cells
//   is computed as:
//     Pc = ((H' - 1) * S + K - H) / 2
//   When the stride is 1, the expression simplifies to
//     H' = H, Pc = (K-1)/2.
//   This is where SAME comes from - the output has the same size as the
//   input.
//
// - When Padding = VALID: the output size is computed as
//     H' = ceil(float(H - K + 1) / float(S))
//   and the number of padded cells is always zero.
//   When the stride is 1, the expression simplifies to
//     H' = H-K+1.
//
// For convolution, mathematically, the output value at location (r')
// is the inner product of two vectors: the chunk of input at
//    ((r'*S-Pr) : (r'*S-Pr+K)),
// and the filter.
//
// For 2D and 3D convolutions, the spatial dimensions are orthogonal, so the
// size and padding of each spatial dimension can be computed by calling
// GetWindowedOutputSize separately for each dimension.
//
Status GetWindowedOutputSize(int64 input_size, int64 filter_size, int64 stride,
                             Padding padding_type, int64* output_size,
                             int64* padding_size);

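// Worked example of the formulas above (values chosen only for
// illustration): with H = 7, K = 3, S = 2,
//   SAME:  H' = ceil(7 / 2) = 4 and Pc = ((4 - 1) * 2 + 3 - 7) / 2 = 1;
//   VALID: H' = ceil((7 - 3 + 1) / 2) = 3 and Pc = 0.
// A corresponding call sketch (not definitive usage, just the same numbers):
//   int64 output_size, padding_size;
//   Status s = GetWindowedOutputSize(/*input_size=*/7, /*filter_size=*/3,
//                                    /*stride=*/2, Padding::SAME,
//                                    &output_size, &padding_size);
//   // Per the formulas above: output_size == 4, padding_size == 1.
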
// The V2 version computes the same outputs with arbitrary dilation_rate.
// The output dimensions are computed as follows:
// - When adding dilation_rate (D), we compute an effective filter size (K'):
//     K' = (K - 1) * D + 1
// - When Padding = SAME: the output size is (H'), where
//     H' = ceil(float(H) / float(S))
//   where ceil is the ceiling function. The number of padded cells
//   is computed as:
//     Pc = ((H' - 1) * S + K' - H) / 2
//   When the stride is 1, the expression simplifies to
//     H' = H, Pc = (K'-1)/2.
//   This is where SAME comes from - the output has the same size as the
//   input.
//
// - When Padding = VALID: the output size is computed as
//     H' = ceil(float(H - K' + 1) / float(S))
//   and the number of padded cells is always zero.
//   When the stride is 1, the expression simplifies to
//     H' = H-K'+1.
//
// If you want to use EXPLICIT padding, GetWindowedOutputSizeVerboseV2 must be
// called instead.
//
// TODO(b/67112639): Merge V2 versions and the original versions eventually.
Status GetWindowedOutputSizeV2(int64 input_size, int64 filter_size,
                               int64 dilation_rate, int64 stride,
                               Padding padding_type, int64* output_size,
                               int64* padding_size);

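// Worked example of the effective filter size (values chosen only for
// illustration): with K = 3 and D = 2, K' = (3 - 1) * 2 + 1 = 5. For H = 7
// and S = 1:
//   SAME:  H' = 7 and Pc = (5 - 1) / 2 = 2;
//   VALID: H' = 7 - 5 + 1 = 3 and Pc = 0.
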
// Returns the same output dimensions as in GetWindowedOutputSize, but returns
// verbose padding dimensions (before/after), and EXPLICIT padding is supported.
// When padding_type is EXPLICIT, *padding_before and *padding_after must
// already point to initialized integers with the padding amounts. Otherwise,
// *padding_before and *padding_after are set by this function, and any
// excess padding (caused by an odd padding size value) is added to the
// 'padding_after' dimension.
Status GetWindowedOutputSizeVerbose(int64 input_size, int64 filter_size,
                                    int64 stride, Padding padding_type,
                                    int64* output_size, int64* padding_before,
                                    int64* padding_after);

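// Illustrative example of the before/after split (values made up): with
// input_size = 10, filter_size = 4, stride = 2 and SAME padding, the output
// size is ceil(10 / 2) = 5 and the total padding is (5 - 1) * 2 + 4 - 10 = 2,
// giving *padding_before == 1 and *padding_after == 1. If filter_size were 5,
// the total padding would be 3 (odd), and the excess cell would go to
// *padding_after: 1 before, 2 after.
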
// The V2 version computes the same outputs with arbitrary dilation_rate. For
// detailed equations, refer to the comments for GetWindowedOutputSizeV2().
Status GetWindowedOutputSizeVerboseV2(int64 input_size, int64 filter_size,
                                      int64 dilation_rate, int64 stride,
                                      Padding padding_type, int64* output_size,
                                      int64* padding_before,
                                      int64* padding_after);

// Given an input tensor, kernel, stride and padding type, populates the 3D size
// of the output tensor and padding to be applied to the input tensor at the
// lower end of every dimension. Use for 3D convolutions, where the input data
// is padded with zeros, as well as for 3D avg/max pooling, where the input data
// is padded with invalid values that are not considered for pooling. EXPLICIT
// padding is not supported.
Status Get3dOutputSize(const std::array<int64, 3>& input,
                       const std::array<int64, 3>& window,
                       const std::array<int64, 3>& strides,
                       Padding padding_type, std::array<int64, 3>* output_ptr,
                       std::array<int64, 3>* padding_ptr);

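// Call sketch (sizes chosen only for illustration): a 16x16x16 volume with a
// 3x3x3 window and stride 1 in every dimension.
//   std::array<int64, 3> out, pad;
//   Status s = Get3dOutputSize({{16, 16, 16}}, {{3, 3, 3}}, {{1, 1, 1}},
//                              Padding::SAME, &out, &pad);
//   // Per the SAME formula applied per dimension: out == {16, 16, 16} and
//   // pad == {1, 1, 1} (the lower-end padding, (3 - 1) / 2 = 1).
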
// The V2 version computes the same outputs with arbitrary dilation_rate. For
// detailed equations, refer to the comments for GetWindowedOutputSizeV2().
Status Get3dOutputSizeV2(const std::array<int64, 3>& input,
                         const std::array<int64, 3>& window,
                         const std::array<int64, 3>& dilations,
                         const std::array<int64, 3>& strides,
                         Padding padding_type, std::array<int64, 3>* output_ptr,
                         std::array<int64, 3>* padding_ptr);

namespace shape_inference {

// Like GetWindowedOutputSize, but deals with DimensionHandles. Does not support
// EXPLICIT padding.
Status GetWindowedOutputSizeFromDims(InferenceContext* c,
                                     DimensionHandle input_size,
                                     DimensionOrConstant filter_size,
                                     int64 stride, Padding padding_type,
                                     DimensionHandle* output_size);

// The V2 version computes the same outputs with arbitrary dilation_rate, and
// supports EXPLICIT padding. For detailed equations, refer to the comments
// for GetWindowedOutputSizeV2(). The 'padding_before' and 'padding_after'
// parameters are only used if padding_type == EXPLICIT.
Status GetWindowedOutputSizeFromDimsV2(
    InferenceContext* c, DimensionHandle input_size,
    DimensionOrConstant filter_size, int64 dilation_rate, int64 stride,
    Padding padding_type, int64 padding_before, int64 padding_after,
    DimensionHandle* output_size);

// Transfers shape of input(0) to output(0).
Status UnchangedShape(shape_inference::InferenceContext* c);

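// Sketch of how these shape functions are typically attached to an op
// registration (the op name and signature here are hypothetical, for
// illustration only):
//   REGISTER_OP("MyIdentityLikeOp")
//       .Input("x: float")
//       .Output("y: float")
//       .SetShapeFn(shape_inference::UnchangedShape);
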
// Transfers shape of input(0) to output(0), after asserting its rank is <rank>.
inline Status UnchangedShapeWithRank(shape_inference::InferenceContext* c,
                                     int32 rank) {
  ShapeHandle out;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &out));
  c->set_output(0, out);
  return Status::OK();
}

// Transfers shape of input(0) to output(0), after asserting its rank >= <rank>.
inline Status UnchangedShapeWithRankAtLeast(
    shape_inference::InferenceContext* c, int32 rank) {
  ShapeHandle out;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), rank, &out));
  c->set_output(0, out);
  return Status::OK();
}

// Transfers shape of input(0) to output(0), after asserting its rank <= <rank>.
inline Status UnchangedShapeWithRankAtMost(shape_inference::InferenceContext* c,
                                           int32 rank) {
  ShapeHandle out;
  TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), rank, &out));
  c->set_output(0, out);
  return Status::OK();
}

// Shape function for use with ops that have no outputs.
inline Status NoOutputs(shape_inference::InferenceContext* c) {
  return Status::OK();
}

// Shape function for ops that output a single scalar value.
inline Status ScalarShape(shape_inference::InferenceContext* c) {
  c->set_output(0, c->Scalar());
  return Status::OK();
}

// Shape function for binary ops where both inputs and the output match.
inline Status MergeBothInputsShapeFn(InferenceContext* c) {
  ShapeHandle out;
  TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &out));
  c->set_output(0, out);
  return Status::OK();
}

// Returns a new shape with the specified dims arranged in the specified
// format. The returned value is owned by this context.
// Note: if format = "FORMAT_NCHW_VECT_C" then C represents the outer_depth.
Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
                           const std::vector<DimensionOrConstant>& spatial,
                           DimensionOrConstant C, ShapeHandle* out,
                           shape_inference::InferenceContext* context);

// Shape function for MatMul-like operations.
Status MatMulShape(shape_inference::InferenceContext* c);

// Shape function for BiasAdd-like operations.
Status BiasAddShape(shape_inference::InferenceContext* c);

// Shape function for BiasAddGrad-like operations.
Status BiasAddGradShape(shape_inference::InferenceContext* c);

// Shape function for Conv2D-like operations that support explicit padding.
Status Conv2DShapeWithExplicitPadding(shape_inference::InferenceContext* c);

// Shape function for Conv2D-like operations that do not support explicit
// padding.
Status Conv2DShape(shape_inference::InferenceContext* c);

// Shape function for Conv3D-like operations.
Status Conv3DShape(shape_inference::InferenceContext* c);

// Shape function for DepthwiseConv2D-like operations.
Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c);

// Shape function for AvgPool-like operations.
Status AvgPoolShape(shape_inference::InferenceContext* c);

// Shape function for FusedBatchNorm and FusedBatchNormV2 operations.
Status FusedBatchNormShape(shape_inference::InferenceContext* c);

// Shape function for FusedBatchNormGrad and FusedBatchNormGradV2 operations.
Status FusedBatchNormGradShape(shape_inference::InferenceContext* c);

// Shape function for MaxPool-like operations.
Status MaxPoolShape(shape_inference::InferenceContext* c);

// Shape function for MaxPoolV2-like operations.
Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs);

// Shape function for 3D Pooling operations.
Status Pool3DShape(shape_inference::InferenceContext* c);

// Shape function for use with ops whose output shapes are unknown.
Status UnknownShape(shape_inference::InferenceContext* c);

// Shape function for reduction operations.
Status ReductionShape(shape_inference::InferenceContext* c);

// Shape function for concat operations.
// <num_inputs_to_concat> is the number of inputs to concatenate; they are
// taken from inputs [1, num_inputs_to_concat] of the op. Input 0 is the
// concat_dim input.
Status ConcatShape(shape_inference::InferenceContext* c,
                   int num_inputs_to_concat);

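// Illustrative example (shapes made up): concatenating inputs of shape
// [2, 3] and [2, 5] with concat_dim = 1 yields an output of shape [2, 8];
// all non-concat dimensions must match.
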
// Shape function for concat operations.
Status ConcatV2Shape(shape_inference::InferenceContext* c);

Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat);

// Shape function helper for binary operators that broadcast their inputs.
// The broadcasted shape of <shape_x> and <shape_y> is returned in <out>.
// Note: out cannot be NULL.
Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
                                            ShapeHandle shape_x,
                                            ShapeHandle shape_y,
                                            ShapeHandle* out);

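// Illustrative example of the broadcasting rule (shapes made up): merging
// shape_x = [2, 1, 5] with shape_y = [7, 1] aligns the shapes from the right
// and expands size-1 dimensions, giving *out = [2, 7, 5]. Dimensions that
// are neither equal nor 1 result in an error.
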
// Shape function for binary operators that broadcast their inputs and write
// the result to output(output_index).
inline Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c,
                                             int output_index) {
  ShapeHandle out;
  TF_RETURN_IF_ERROR(
      BroadcastBinaryOpOutputShapeFnHelper(c, c->input(0), c->input(1), &out));
  c->set_output(output_index, out);
  return Status::OK();
}

// Shape function for binary operators that broadcast their inputs.
// Tested by ops/math_ops_test.cc.
inline Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
  return BroadcastBinaryOpOutputShapeFn(c, 0);
}

// Shape function for random operations.
Status RandomShape(shape_inference::InferenceContext* c);

// Shape function for Slice operations.
Status SliceShape(shape_inference::InferenceContext* c);

// Validates that the 3 component tensors of a sparse tensor have the proper
// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
                            ShapeHandle values_shape, ShapeHandle shape_shape);

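// For reference, the standard SparseTensor layout: for a sparse tensor with
// N non-zero values and dense rank ndims, indices_shape is [N, ndims],
// values_shape is [N], and shape_shape is [ndims].
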
// Shape function for ScatterNd update/add/sub/... operations.
Status ScatterNdUpdateShape(InferenceContext* c);

// Shape function for ops with an explicit "shape" attribute.
Status ExplicitShape(InferenceContext* c);

// Shape function for multiple-output ops with an explicit "shapes" attribute.
Status ExplicitShapes(InferenceContext* c);

// Shape function for SparseReduceMax and SparseReduceSum.
Status SparseReduceShapeFn(InferenceContext* c);

}  // namespace shape_inference

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_
    337