// tensorflow/core/framework/common_shape_fns.h
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #ifndef TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_
     16 #define TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_
     17 
     18 #include <array>
     19 
     20 #include "tensorflow/core/framework/shape_inference.h"
     21 #include "tensorflow/core/util/padding.h"
     22 #include "tensorflow/core/util/tensor_format.h"
     23 
     24 namespace tensorflow {
     25 
     26 // GetWindowedOutputSize(): Given an input tensor, kernel, stride and padding
     27 // type, the function computes the output and padding dimensions.
     28 //
     29 // For example, ignoring batches or multiple features, a 1D convolution
     30 // takes as input a 1D tensor of shape (H), and convolves it with a filter of
     31 // shape (K).
     32 //
     33 // It also takes in a few additional parameters:
     34 //
     35 // Stride (S): the stride with which we apply the filters. This is the offset
     36 // between locations where we apply the filters. A larger stride
     37 // means that the output will be spatially smaller.
     38 //
     39 // Padding (P): the padding we apply to the input tensor along each
     40 // dimension. This is usually used to make sure that the spatial dimensions
     41 // do not shrink when we progress with convolutions. Two types of padding are
     42 // often used:
     43 //   SAME: the pad value is computed so that the output will have size
     44 //   VALID: no padding is carried out.
     45 // The padded area is zero-filled.
     46 //
     47 // The output dimensions for convolution and many other operations, when given
     48 // all the parameters above, are as follows:
     49 // - When Padding = SAME: the output size is (H'), where
     50 //     H' = ceil(float(H) / float(S))
     51 //   where ceil is the ceiling function. The number of padded cells
     52 //   is computed as:
     53 //     Pc = ((H' - 1) * S + K - H) / 2
     54 //   When the stride is 1, the expression simplifies to
     55 //     H' = H, Pc = (K-1)/2.
     56 //   This is where SAME comes from - the output has the same size as the input
     57 //   has.
     58 //
     59 // - When Padding = VALID: the output size is computed as
     60 //     H' = ceil(float(H - K + 1) / float(S))
     61 //   and the number of padded cells is always zero.
     62 //   When the stride is 1, the expression simplifies to
     63 //     H' = H-K+1.
     64 //
     65 // For convolution, mathematically, the output value at location (r')
     66 // is the inner product of two vectors: the chunk of input at
     67 //    ((r'*S-Pr) : (r'*S-Pr+K)),
     68 // and the filter.
     69 //
     70 // For 2D and 3D convolutions, the spatial dimensions are orthogonal, so the
     71 // size and padding of each spatial dimension can be computed by calling
     72 // GetWindowedOutputSize separately for each dimension.
     73 //
     74 Status GetWindowedOutputSize(int64 input_size, int64 filter_size, int64 stride,
     75                              Padding padding_type, int64* output_size,
     76                              int64* padding_size);
     77 
     78 // The V2 version computes the same outputs with arbitrary dilation_rate.
     79 // The output dimensions are computed as follows:
     80 // - When adding dilation_rate (D), we compute an effective filter size (K'):
     81 //     K' = (K - 1) * D + 1
     82 // - When Padding = SAME: the output size is (H'), where
     83 //     H' = ceil(float(H) / float(S))
     84 //   where ceil is the ceiling function. The number of padded cells
     85 //   is computed as:
     86 //     Pc = ((H' - 1) * S + K' - H) / 2
     87 //   When the stride is 1, the expression simplifies to
     88 //     H' = H, Pc = (K'-1)/2.
     89 //   This is where SAME comes from - the output has the same size as the input
     90 //   has.
     91 //
     92 // - When Padding = VALID: the output size is computed as
     93 //     H' = ceil(float(H - K' + 1) / float(S))
     94 //   and the number of padded cells is always zero.
     95 //   When the stride is 1, the expression simplifies to
     96 //     H' = H-K'+1.
     97 //
     98 // TODO(b/67112639): Merge V2 versions and the original versions eventually.
     99 Status GetWindowedOutputSizeV2(int64 input_size, int64 filter_size,
    100                                int64 dilation_rate, int64 stride,
    101                                Padding padding_type, int64* output_size,
    102                                int64* padding_size);
    103 
    104 // Returns the same output dimensions as in GetWindowedOutputSize, but returns
    105 // verbose padding dimensions (before/after). Any excess padding
    106 // (caused by an odd padding size value) is added to the 'padding_after'
    107 // dimension.
    108 Status GetWindowedOutputSizeVerbose(int64 input_size, int64 filter_size,
    109                                     int64 stride, Padding padding_type,
    110                                     int64* output_size, int64* padding_before,
    111                                     int64* padding_after);
    112 
    113 // The V2 version computes the same outputs with arbitrary dilation_rate. For
    114 // detailed equations, refer to the comments for GetWindowedOutputSizeV2().
    115 Status GetWindowedOutputSizeVerboseV2(int64 input_size, int64 filter_size,
    116                                       int64 dilation_rate, int64 stride,
    117                                       Padding padding_type, int64* output_size,
    118                                       int64* padding_before,
    119                                       int64* padding_after);
    120 
    121 // Given an input tensor, kernel, stride and padding type, populates the 3D size
    122 // of the output tensor and padding to be applied to the input tensor at the
    123 // lower end of every dimension. Use for 3D convolutions, where the input data
    124 // is padded with zeros, as well as for 3D avg/max pooling, where the input data
    125 // is padded with invalid values that are not considered for pooling.
    126 Status Get3dOutputSize(const std::array<int64, 3>& input,
    127                        const std::array<int64, 3>& window,
    128                        const std::array<int64, 3>& strides,
    129                        Padding padding_type, std::array<int64, 3>* output_ptr,
    130                        std::array<int64, 3>* padding_ptr);
    131 
    132 // The V2 version computes the same outputs with arbitrary dilation_rate. For
    133 // detailed equations, refer to the comments for GetWindowedOutputSizeV2().
    134 Status Get3dOutputSizeV2(const std::array<int64, 3>& input,
    135                          const std::array<int64, 3>& window,
    136                          const std::array<int64, 3>& dilations,
    137                          const std::array<int64, 3>& strides,
    138                          Padding padding_type, std::array<int64, 3>* output_ptr,
    139                          std::array<int64, 3>* padding_ptr);
    140 
    141 namespace shape_inference {
    142 
    143 // Like GetWindowedOutputSize, but deals with DimensionHandles.
    144 Status GetWindowedOutputSizeFromDims(InferenceContext* c,
    145                                      DimensionHandle input_size,
    146                                      DimensionOrConstant filter_size,
    147                                      int64 stride, Padding padding_type,
    148                                      DimensionHandle* output_size);
    149 
    150 // The V2 version computes the same outputs with arbitrary dilation_rate. For
    151 // detailed equations, refer to the comments for GetWindowedOutputSizeV2().
    152 Status GetWindowedOutputSizeFromDimsV2(InferenceContext* c,
    153                                        DimensionHandle input_size,
    154                                        DimensionOrConstant filter_size,
    155                                        int64 dilation_rate, int64 stride,
    156                                        Padding padding_type,
    157                                        DimensionHandle* output_size);
    158 
    159 // Transfers shape of input(0) to output(0).
    160 Status UnchangedShape(shape_inference::InferenceContext* c);
    161 
    162 // Transfers shape of input(0) to output(0), after asserting its rank is <rank>.
    163 inline Status UnchangedShapeWithRank(shape_inference::InferenceContext* c,
    164                                      int32 rank) {
    165   ShapeHandle out;
    166   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &out));
    167   c->set_output(0, out);
    168   return Status::OK();
    169 }
    170 
    171 // Transfers shape of input(0) to output(0), after asserting its rank >= <rank>.
    172 inline Status UnchangedShapeWithRankAtLeast(
    173     shape_inference::InferenceContext* c, int32 rank) {
    174   ShapeHandle out;
    175   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), rank, &out));
    176   c->set_output(0, out);
    177   return Status::OK();
    178 }
    179 
    180 // Transfers shape of input(0) to output(0), after asserting its rank <= <rank>.
    181 inline Status UnchangedShapeWithRankAtMost(shape_inference::InferenceContext* c,
    182                                            int32 rank) {
    183   ShapeHandle out;
    184   TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), rank, &out));
    185   c->set_output(0, out);
    186   return Status::OK();
    187 }
    188 
// Shape function for use with ops that have no outputs: performs no
// inference and always reports success.
inline Status NoOutputs(shape_inference::InferenceContext* c) {
  return Status::OK();
}
    193 
    194 // Shape function for ops that output a single scalar value.
    195 inline Status ScalarShape(shape_inference::InferenceContext* c) {
    196   c->set_output(0, c->Scalar());
    197   return Status::OK();
    198 }
    199 
    200 // Shape function for binary ops where both inputs and the output match.
    201 inline Status MergeBothInputsShapeFn(InferenceContext* c) {
    202   ShapeHandle out;
    203   TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &out));
    204   c->set_output(0, out);
    205   return Status::OK();
    206 }
    207 
    208 // Returns a new shape with the specified dims arranged in the specified
    209 // format. The returned value is owned by this context.
    210 // Note: if format = "FORMAT_NCHW_VECT_C" then C represents the outer_depth.
    211 Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
    212                            const std::vector<DimensionOrConstant>& spatial,
    213                            DimensionOrConstant C, ShapeHandle* out,
    214                            shape_inference::InferenceContext* context);
    215 
    216 // Shape function for MatMul-like operations.
    217 Status MatMulShape(shape_inference::InferenceContext* c);
    218 
    219 // Shape function for BiasAdd-like operations.
    220 Status BiasAddShape(shape_inference::InferenceContext* c);
    221 
    222 // Shape function for BiasAddGrad-like operations.
    223 Status BiasAddGradShape(shape_inference::InferenceContext* c);
    224 
    225 // Shape function for Conv2D-like operations.
    226 Status Conv2DShape(shape_inference::InferenceContext* c);
    227 
    228 // Shape function for Conv3D-like operations.
    229 Status Conv3DShape(shape_inference::InferenceContext* c);
    230 
    231 // Shape function for DepthwiseConv2D-like operations.
    232 Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c);
    233 
    234 // Shape function for AvgPool-like operations.
    235 Status AvgPoolShape(shape_inference::InferenceContext* c);
    236 
    237 // Shape function for FusedBatchNorm and FusedBatchNormV2 operations.
    238 Status FusedBatchNormShape(shape_inference::InferenceContext* c);
    239 
    240 // Shape function for FusedBatchNormGrad and FusedBatchNormGradV2 operations.
    241 Status FusedBatchNormGradShape(shape_inference::InferenceContext* c);
    242 
    243 // Shape function for MaxPool-like operations.
    244 Status MaxPoolShape(shape_inference::InferenceContext* c);
    245 
    246 // Shape function for MaxPoolV2-like operations.
    247 Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs);
    248 
    249 // Shape function for 3D Pooling operations.
    250 Status Pool3DShape(shape_inference::InferenceContext* c);
    251 
    252 // Shape function for use with ops whose output shapes are unknown.
    253 Status UnknownShape(shape_inference::InferenceContext* c);
    254 
    255 // Shape function for reduction operations.
    256 Status ReductionShape(shape_inference::InferenceContext* c);
    257 
    258 // Shape function for concat operations.
    259 // <num_inputs_to_concat> is the number of inputs to concatenate; they are
    260 // taken from inputs [1, num_inputs_to_concat] of the op.  Input 0 is the
    261 // concat_dim input.
    262 Status ConcatShape(shape_inference::InferenceContext* c,
    263                    int num_inputs_to_concat);
    264 
    265 // Shape function for concat operations.
    266 Status ConcatV2Shape(shape_inference::InferenceContext* c);
    267 
    268 // Shape function for binary operators that broadcast their inputs.
    269 // Tested by ops/math_ops_test.cc.
    270 Status BroadcastBinaryOpShapeFn(InferenceContext* c);
    271 
    272 // Shape function for random operations.
    273 Status RandomShape(shape_inference::InferenceContext* c);
    274 
    275 // Validates the 3 component tensors of a sparse tensor have the proper
    276 // shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
    277 Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
    278                             ShapeHandle values_shape, ShapeHandle shape_shape);
    279 
    280 // Shape function for ScatterNd update/add/sub/... operations.
    281 Status ScatterNdUpdateShape(InferenceContext* c);
    282 
    283 // Shape function for ops with an explicit "shape" attribute.
    284 Status ExplicitShape(InferenceContext* c);
    285 
    286 }  // namespace shape_inference
    287 
    288 }  // namespace tensorflow
    289 
    290 #endif  // TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_
    291