Home | History | Annotate | Download | only in framework
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/core/framework/common_shape_fns.h"
     16 #include "tensorflow/core/framework/attr_value.pb.h"
     17 
     18 namespace tensorflow {
     19 
     20 Status GetWindowedOutputSizeVerboseV2(int64 input_size, int64 filter_size,
     21                                       int64 dilation_rate, int64 stride,
     22                                       Padding padding_type, int64* output_size,
     23                                       int64* padding_before,
     24                                       int64* padding_after) {
     25   if (stride <= 0) {
     26     return errors::InvalidArgument("Stride must be > 0, but got ", stride);
     27   }
     28   if (dilation_rate < 1) {
     29     return errors::InvalidArgument("Dilation rate must be >= 1, but got ",
     30                                    dilation_rate);
     31   }
     32 
     33   // See also the parallel implementation in GetWindowedOutputSizeFromDimsV2.
     34   int64 effective_filter_size = (filter_size - 1) * dilation_rate + 1;
     35   switch (padding_type) {
     36     case Padding::VALID:
     37       *output_size = (input_size - effective_filter_size + stride) / stride;
     38       *padding_before = *padding_after = 0;
     39       break;
     40     case Padding::SAME:
     41       *output_size = (input_size + stride - 1) / stride;
     42       const int64 padding_needed =
     43           std::max(0LL, (*output_size - 1) * stride + effective_filter_size -
     44                             input_size);
     45       // For odd values of total padding, add more padding at the 'right'
     46       // side of the given dimension.
     47       *padding_before = padding_needed / 2;
     48       *padding_after = padding_needed - *padding_before;
     49       break;
     50   }
     51   if (*output_size < 0) {
     52     return errors::InvalidArgument(
     53         "Computed output size would be negative: ", *output_size,
     54         " [input_size: ", input_size,
     55         ", effective_filter_size: ", effective_filter_size,
     56         ", stride: ", stride, "]");
     57   }
     58   return Status::OK();
     59 }
     60 
     61 Status GetWindowedOutputSizeVerbose(int64 input_size, int64 filter_size,
     62                                     int64 stride, Padding padding_type,
     63                                     int64* output_size, int64* padding_before,
     64                                     int64* padding_after) {
     65   return GetWindowedOutputSizeVerboseV2(input_size, filter_size,
     66                                         /*dilation_rate=*/1, stride,
     67                                         padding_type, output_size,
     68                                         padding_before, padding_after);
     69 }
     70 
     71 Status GetWindowedOutputSize(int64 input_size, int64 filter_size, int64 stride,
     72                              Padding padding_type, int64* output_size,
     73                              int64* padding_size) {
     74   int64 padding_after_unused;
     75   return GetWindowedOutputSizeVerbose(input_size, filter_size, stride,
     76                                       padding_type, output_size, padding_size,
     77                                       &padding_after_unused);
     78 }
     79 
     80 Status GetWindowedOutputSizeV2(int64 input_size, int64 filter_size,
     81                                int64 dilation_rate, int64 stride,
     82                                Padding padding_type, int64* output_size,
     83                                int64* padding_size) {
     84   int64 padding_after_unused;
     85   return GetWindowedOutputSizeVerboseV2(input_size, filter_size, dilation_rate,
     86                                         stride, padding_type, output_size,
     87                                         padding_size, &padding_after_unused);
     88 }
     89 
     90 Status Get3dOutputSize(const std::array<int64, 3>& input,
     91                        const std::array<int64, 3>& window,
     92                        const std::array<int64, 3>& strides,
     93                        Padding padding_type, std::array<int64, 3>* output_ptr,
     94                        std::array<int64, 3>* padding_ptr) {
     95   for (size_t i = 0; i < input.size(); ++i) {
     96     TF_RETURN_IF_ERROR(GetWindowedOutputSize(input[i], window[i], strides[i],
     97                                              padding_type, &(*output_ptr)[i],
     98                                              &(*padding_ptr)[i]));
     99   }
    100   return Status::OK();
    101 }
    102 
    103 Status Get3dOutputSizeV2(const std::array<int64, 3>& input,
    104                          const std::array<int64, 3>& window,
    105                          const std::array<int64, 3>& dilations,
    106                          const std::array<int64, 3>& strides,
    107                          Padding padding_type, std::array<int64, 3>* output_ptr,
    108                          std::array<int64, 3>* padding_ptr) {
    109   for (size_t i = 0; i < input.size(); ++i) {
    110     TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2(
    111         input[i], window[i], dilations[i], strides[i], padding_type,
    112         &(*output_ptr)[i], &(*padding_ptr)[i]));
    113   }
    114   return Status::OK();
    115 }
    116 
    117 namespace shape_inference {
    118 
// The V2 version computes windowed output size with arbitrary dilation_rate,
// while the original version only handles the cases where dilation_rates equal
// to 1.
//
// Symbolic (shape-inference) counterpart of GetWindowedOutputSizeVerboseV2:
// derives the output DimensionHandle from the input dimension, filter size,
// dilation, stride and padding scheme. Unlike the Verbose variant it does not
// report the padding amounts. Returns InvalidArgument for stride <= 0 or
// dilation_rate < 1.
Status GetWindowedOutputSizeFromDimsV2(
    shape_inference::InferenceContext* c,
    shape_inference::DimensionHandle input_size,
    shape_inference::DimensionOrConstant filter_size, int64 dilation_rate,
    int64 stride, Padding padding_type,
    shape_inference::DimensionHandle* output_size) {
  if (stride <= 0) {
    return errors::InvalidArgument("Stride must be > 0, but got ", stride);
  }

  if (dilation_rate < 1) {
    return errors::InvalidArgument("Dilation rate must be >= 1, but got ",
                                   dilation_rate);
  }

  // See also the parallel implementation in GetWindowedOutputSizeVerbose.
  switch (padding_type) {
    case Padding::VALID:
      if (dilation_rate > 1) {
        // Effective window = (filter_size - 1) * dilation_rate + 1, built up
        // with symbolic dimension arithmetic so unknown dims propagate.
        DimensionHandle window_size;
        TF_RETURN_IF_ERROR(
            c->Subtract(c->MakeDim(filter_size), 1, &window_size));
        TF_RETURN_IF_ERROR(
            c->Multiply(window_size, dilation_rate, &window_size));
        TF_RETURN_IF_ERROR(c->Add(window_size, 1, &window_size));
        TF_RETURN_IF_ERROR(c->Subtract(input_size, window_size, output_size));
      } else {
        TF_RETURN_IF_ERROR(c->Subtract(input_size, filter_size, output_size));
      }
      // output = ceil((input - window + 1) / stride), expressed as
      // floor((input - window + stride) / stride).
      TF_RETURN_IF_ERROR(c->Add(*output_size, stride, output_size));
      TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
                                   /*evenly_divisible=*/false, output_size));
      break;
    case Padding::SAME:
      // output = ceil(input / stride); the implicit padding is not computed.
      TF_RETURN_IF_ERROR(c->Add(input_size, stride - 1, output_size));
      TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
                                   /*evenly_divisible=*/false, output_size));
      break;
  }
  return Status::OK();
}
    163 
    164 Status GetWindowedOutputSizeFromDims(
    165     shape_inference::InferenceContext* c,
    166     shape_inference::DimensionHandle input_size,
    167     shape_inference::DimensionOrConstant filter_size, int64 stride,
    168     Padding padding_type, shape_inference::DimensionHandle* output_size) {
    169   return GetWindowedOutputSizeFromDimsV2(c, input_size, filter_size,
    170                                          /*dilation_rate=*/1, stride,
    171                                          padding_type, output_size);
    172 }
    173 
    174 Status UnchangedShape(shape_inference::InferenceContext* c) {
    175   c->set_output(0, c->input(0));
    176   return Status::OK();
    177 }
    178 
    179 Status MatMulShape(shape_inference::InferenceContext* c) {
    180   ShapeHandle a;
    181   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a));
    182 
    183   ShapeHandle b;
    184   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b));
    185 
    186   bool transpose_a, transpose_b;
    187   TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a));
    188   TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b));
    189   DimensionHandle output_rows = transpose_a ? c->Dim(a, 1) : c->Dim(a, 0);
    190   DimensionHandle output_cols = transpose_b ? c->Dim(b, 0) : c->Dim(b, 1);
    191 
    192   // Validate that the inner shapes are compatible.
    193   DimensionHandle inner_a = transpose_a ? c->Dim(a, 0) : c->Dim(a, 1);
    194   DimensionHandle inner_b = transpose_b ? c->Dim(b, 1) : c->Dim(b, 0);
    195   DimensionHandle merged;
    196   TF_RETURN_IF_ERROR(c->Merge(inner_a, inner_b, &merged));
    197 
    198   c->set_output(0, c->Matrix(output_rows, output_cols));
    199   return Status::OK();
    200 }
    201 
// Shape function for BiasAdd. The output has the same shape as the input,
// with the channel dimension (dim -3 for NCHW, the last dim otherwise)
// merged against the length of the rank-1 bias input.
Status BiasAddShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;

  // Fetch the data_format attribute, which may not exist.
  string data_format;
  Status s = c->GetAttr("data_format", &data_format);

  // NCHW needs rank >= 3 so that dim -3 (the channel dim) exists; otherwise
  // rank >= 2 suffices since the bias is applied along the last dim.
  if (s.ok() && data_format == "NCHW") {
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
  } else {
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
  }

  // The bias must be a vector.
  ShapeHandle bias_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &bias_shape));
  DimensionHandle bias_dim = c->Dim(bias_shape, 0);

  // If rank unknown, return unknown shape.
  if (!c->RankKnown(input_shape)) {
    c->set_output(0, c->UnknownShape());
    return Status::OK();
  }

  // Output has the same shape as the input, and matches the length of
  // the bias in its bias dimension.
  ShapeHandle output_shape;
  if (s.ok() && data_format == "NCHW") {
    // Merge the length of bias_shape into the third to last dimension:
    // rebuild the shape as [dims before -3] + [merged channel] + [last two].
    ShapeHandle first;
    TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, -3, &first));

    ShapeHandle last;
    TF_RETURN_IF_ERROR(c->Subshape(input_shape, -2, &last));

    DimensionHandle input_bias_dim = c->Dim(input_shape, -3);
    DimensionHandle merged_bias_dim;
    TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim));
    ShapeHandle merged_bias = c->Vector(merged_bias_dim);

    ShapeHandle temp;
    TF_RETURN_IF_ERROR(c->Concatenate(first, merged_bias, &temp));
    TF_RETURN_IF_ERROR(c->Concatenate(temp, last, &output_shape));
  } else {
    // Default (NHWC-like) layout: merge the bias into the last dimension.
    ShapeHandle all_but_bias;
    TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, -1, &all_but_bias));

    DimensionHandle input_bias_dim = c->Dim(input_shape, -1);
    DimensionHandle merged_bias_dim;
    TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim));

    ShapeHandle merged_bias = c->Vector(merged_bias_dim);
    TF_RETURN_IF_ERROR(
        c->Concatenate(all_but_bias, merged_bias, &output_shape));
  }

  c->set_output(0, output_shape);
  return Status::OK();
}
    260 
    261 Status BiasAddGradShape(shape_inference::InferenceContext* c) {
    262   ShapeHandle input_shape;
    263   // Fetch the data_format attribute, which may not exist.
    264   string data_format;
    265   Status s = c->GetAttr("data_format", &data_format);
    266 
    267   if (s.ok() && data_format == "NCHW") {
    268     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
    269     c->set_output(0, c->Vector(c->Dim(input_shape, -3)));
    270   } else {
    271     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
    272     c->set_output(0, c->Vector(c->Dim(input_shape, -1)));
    273   }
    274 
    275   return Status::OK();
    276 }
    277 
    278 Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format,
    279                                      const ShapeHandle shape_handle,
    280                                      const string& tensor_name,
    281                                      shape_inference::InferenceContext* c) {
    282   if (tensor_format == FORMAT_NCHW_VECT_C) {
    283     // Check that the vect dim has size 4.
    284     const int num_dims = c->Rank(shape_handle);
    285     DimensionHandle vect_dim = c->Dim(
    286         shape_handle, GetTensorInnerFeatureDimIndex(num_dims, tensor_format));
    287     DimensionHandle unused_vect_dim;
    288     TF_RETURN_IF_ERROR(c->WithValue(vect_dim, 4, &unused_vect_dim));
    289   }
    290 
    291   return Status::OK();
    292 }
    293 
    294 Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
    295                            const std::vector<DimensionOrConstant>& spatial,
    296                            DimensionOrConstant C, ShapeHandle* out,
    297                            shape_inference::InferenceContext* context) {
    298   const int num_dims = GetTensorDimsFromSpatialDims(spatial.size(), format);
    299   std::vector<DimensionHandle> dims_actual(num_dims);
    300   dims_actual[GetTensorBatchDimIndex(num_dims, format)] = context->MakeDim(N);
    301   int outer_c_index = GetTensorFeatureDimIndex(num_dims, format);
    302   dims_actual[outer_c_index] = context->MakeDim(C);
    303   if (format == FORMAT_NCHW_VECT_C) {
    304     dims_actual[GetTensorInnerFeatureDimIndex(num_dims, format)] =
    305         context->MakeDim(4);
    306   }
    307   for (int spatial_dim = 0; spatial_dim < spatial.size(); spatial_dim++) {
    308     dims_actual[GetTensorSpatialDimIndex(num_dims, format, spatial_dim)] =
    309         context->MakeDim(spatial[spatial_dim]);
    310   }
    311   *out = context->MakeShape(dims_actual);
    312   return Status::OK();
    313 }
    314 
    315 Status DimensionsFromShape(ShapeHandle shape, TensorFormat format,
    316                            DimensionHandle* batch_dim,
    317                            gtl::MutableArraySlice<DimensionHandle> spatial_dims,
    318                            DimensionHandle* filter_dim,
    319                            InferenceContext* context) {
    320   const int32 rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format);
    321   // Batch.
    322   *batch_dim = context->Dim(shape, GetTensorBatchDimIndex(rank, format));
    323   // Spatial.
    324   for (int spatial_dim_index = 0; spatial_dim_index < spatial_dims.size();
    325        ++spatial_dim_index) {
    326     spatial_dims[spatial_dim_index] = context->Dim(
    327         shape, GetTensorSpatialDimIndex(rank, format, spatial_dim_index));
    328   }
    329   // Channel.
    330   *filter_dim = context->Dim(shape, GetTensorFeatureDimIndex(rank, format));
    331   if (format == FORMAT_NCHW_VECT_C) {
    332     TF_RETURN_IF_ERROR(context->Multiply(
    333         *filter_dim,
    334         context->Dim(shape, GetTensorInnerFeatureDimIndex(rank, format)),
    335         filter_dim));
    336   }
    337   return Status::OK();
    338 }
    339 
    340 Status ShapeFromDimensions(DimensionHandle batch_dim,
    341                            gtl::ArraySlice<DimensionHandle> spatial_dims,
    342                            DimensionHandle filter_dim, TensorFormat format,
    343                            InferenceContext* context, ShapeHandle* shape) {
    344   const int32 rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format);
    345   std::vector<DimensionHandle> out_dims(rank);
    346 
    347   // Batch.
    348   out_dims[tensorflow::GetTensorBatchDimIndex(rank, format)] = batch_dim;
    349   // Spatial.
    350   for (int spatial_dim_index = 0; spatial_dim_index < spatial_dims.size();
    351        ++spatial_dim_index) {
    352     out_dims[tensorflow::GetTensorSpatialDimIndex(
    353         rank, format, spatial_dim_index)] = spatial_dims[spatial_dim_index];
    354   }
    355   // Channel.
    356   if (format == tensorflow::FORMAT_NCHW_VECT_C) {
    357     // When format is NCHW_VECT_C, factor the feature map count
    358     // into the outer feature count and the inner feature count (=4).
    359     TF_RETURN_IF_ERROR(context->Divide(
    360         filter_dim, 4, /*evenly_divisible=*/true,
    361         &out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)]));
    362     out_dims[GetTensorInnerFeatureDimIndex(rank, format)] = context->MakeDim(4);
    363   } else {
    364     out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)] = filter_dim;
    365   }
    366 
    367   *shape = context->MakeShape(out_dims);
    368   return tensorflow::Status::OK();
    369 }
    370 
// Shape function for Conv2D. Supports optional "data_format" (default NHWC,
// incl. NCHW_VECT_C) and "filter_format" (default HWIO, incl. OIHW_VECT_I)
// attributes, validates the stride/dilation attrs, checks that input and
// filter agree on the input channel count, and computes the two spatial
// output dimensions from the padding scheme.
Status Conv2DShape(shape_inference::InferenceContext* c) {
  string data_format_str, filter_format_str;
  // Both format attrs are optional; fall back to the conventional defaults.
  if (!c->GetAttr("data_format", &data_format_str).ok()) {
    data_format_str = "NHWC";
  }
  if (!c->GetAttr("filter_format", &filter_format_str).ok()) {
    filter_format_str = "HWIO";
  }

  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }
  FilterTensorFormat filter_format;
  if (!FilterFormatFromString(filter_format_str, &filter_format)) {
    return errors::InvalidArgument("Invalid filter format string: ",
                                   filter_format_str);
  }

  constexpr int num_spatial_dims = 2;
  // rank is 4 for NHWC/NCHW and 5 for NCHW_VECT_C (extra inner vect dim).
  const int rank = GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
  ShapeHandle conv_input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &conv_input_shape));
  TF_RETURN_IF_ERROR(CheckFormatConstraintsOnShape(
      data_format, conv_input_shape, "conv_input", c));

  // The filter rank should match the input (4 for NCHW, 5 for NCHW_VECT_C).
  ShapeHandle filter_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape));
  TF_RETURN_IF_ERROR(
      CheckFormatConstraintsOnShape(data_format, filter_shape, "filter", c));

  std::vector<int32> dilations;
  TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations));

  if (dilations.size() != 4) {
    return errors::InvalidArgument(
        "Conv2D requires the dilation attribute to contain 4 values, but got: ",
        dilations.size());
  }

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));

  // strides.size() should be 4 (NCHW) even if the input is 5 (NCHW_VECT_C).
  if (strides.size() != 4) {
    return errors::InvalidArgument("Conv2D on data format ", data_format_str,
                                   " requires the stride attribute to contain"
                                   " 4 values, but got: ",
                                   strides.size());
  }

  const int32 stride_rows = GetTensorDim(strides, data_format, 'H');
  const int32 stride_cols = GetTensorDim(strides, data_format, 'W');
  const int32 dilation_rows = GetTensorDim(dilations, data_format, 'H');
  const int32 dilation_cols = GetTensorDim(dilations, data_format, 'W');

  // Split the input shape into batch / spatial / depth dimensions; for
  // NCHW_VECT_C, input_depth_dim already includes the inner vect factor.
  DimensionHandle batch_size_dim;
  DimensionHandle input_depth_dim;
  gtl::InlinedVector<DimensionHandle, 2> input_spatial_dims(2);
  TF_RETURN_IF_ERROR(DimensionsFromShape(conv_input_shape, data_format,
                                         &batch_size_dim, &input_spatial_dims,
                                         &input_depth_dim, c));

  DimensionHandle output_depth_dim = c->Dim(
      filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'O'));
  DimensionHandle filter_rows_dim = c->Dim(
      filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'H'));
  DimensionHandle filter_cols_dim = c->Dim(
      filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'W'));
  DimensionHandle filter_input_depth_dim;
  if (filter_format == FORMAT_OIHW_VECT_I) {
    // Fold the inner input-channels dimension into the filter's input depth.
    TF_RETURN_IF_ERROR(c->Multiply(
        c->Dim(filter_shape,
               GetFilterDimIndex<num_spatial_dims>(filter_format, 'I')),
        c->Dim(filter_shape,
               GetFilterTensorInnerInputChannelsDimIndex(rank, filter_format)),
        &filter_input_depth_dim));
  } else {
    filter_input_depth_dim = c->Dim(
        filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'I'));
  }

  // Check that the input tensor and the filter tensor agree on the input
  // channel count.
  DimensionHandle unused;
  TF_RETURN_IF_ERROR(
      c->Merge(input_depth_dim, filter_input_depth_dim, &unused));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  // Derive the two spatial output dimensions symbolically.
  DimensionHandle output_rows, output_cols;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, input_spatial_dims[0], filter_rows_dim, dilation_rows, stride_rows,
      padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, input_spatial_dims[1], filter_cols_dim, dilation_cols, stride_cols,
      padding, &output_cols));

  ShapeHandle output_shape;
  TF_RETURN_IF_ERROR(
      ShapeFromDimensions(batch_size_dim, {output_rows, output_cols},
                          output_depth_dim, data_format, c, &output_shape));
  c->set_output(0, output_shape);
  return Status::OK();
}
    479 
    480 // TODO(mjanusz): Unify all conv/pooling shape functions.
    481 Status Conv3DShape(shape_inference::InferenceContext* c) {
    482   ShapeHandle input_shape;
    483   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
    484   ShapeHandle filter_shape;
    485   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &filter_shape));
    486 
    487   string data_format;
    488   Status s = c->GetAttr("data_format", &data_format);
    489 
    490   std::vector<int32> strides;
    491   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
    492   if (strides.size() != 5) {
    493     return errors::InvalidArgument(
    494         "Conv3D requires the stride attribute to contain 5 values, but got: ",
    495         strides.size());
    496   }
    497 
    498   int32 stride_planes, stride_rows, stride_cols;
    499   if (s.ok() && data_format == "NCDHW") {
    500     // Convert input_shape to NDHWC.
    501     auto dim = [&](char dimension) {
    502       return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension));
    503     };
    504     input_shape =
    505         c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
    506     stride_planes = strides[2];
    507     stride_cols = strides[3];
    508     stride_rows = strides[4];
    509   } else {
    510     stride_planes = strides[1];
    511     stride_rows = strides[2];
    512     stride_cols = strides[3];
    513   }
    514 
    515   DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
    516   DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
    517   DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
    518   DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
    519 
    520   DimensionHandle filter_planes_dim = c->Dim(filter_shape, 0);
    521   DimensionHandle filter_rows_dim = c->Dim(filter_shape, 1);
    522   DimensionHandle filter_cols_dim = c->Dim(filter_shape, 2);
    523   DimensionHandle output_depth_dim = c->Dim(filter_shape, 4);
    524 
    525   DimensionHandle unused;
    526   TF_RETURN_IF_ERROR(
    527       c->Merge(c->Dim(input_shape, 4), c->Dim(filter_shape, 3), &unused));
    528 
    529   Padding padding;
    530   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
    531   DimensionHandle output_planes, output_rows, output_cols;
    532 
    533   TF_RETURN_IF_ERROR(
    534       GetWindowedOutputSizeFromDims(c, in_planes_dim, filter_planes_dim,
    535                                     stride_planes, padding, &output_planes));
    536   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
    537       c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows));
    538   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
    539       c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols));
    540 
    541   ShapeHandle output_shape;
    542   if (data_format == "NCDHW") {
    543     output_shape = c->MakeShape({batch_size_dim, output_depth_dim,
    544                                  output_planes, output_rows, output_cols});
    545   } else {
    546     output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows,
    547                                  output_cols, output_depth_dim});
    548   }
    549   c->set_output(0, output_shape);
    550   return Status::OK();
    551 }
    552 
    553 Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
    554   ShapeHandle input_shape;
    555   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
    556   ShapeHandle filter_shape;
    557   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape));
    558 
    559   std::vector<int32> strides;
    560   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
    561 
    562   if (strides.size() != 4) {
    563     return errors::InvalidArgument(
    564         "DepthwiseConv2D requires the stride attribute to contain 4 values, "
    565         "but got: ",
    566         strides.size());
    567   }
    568 
    569   string data_format;
    570   Status s = c->GetAttr("data_format", &data_format);
    571   int32 stride_rows;
    572   int32 stride_cols;
    573   if (s.ok() && data_format == "NCHW") {
    574     // Canonicalize input shape to NHWC so the shape inference code below can
    575     // process it.
    576     input_shape =
    577         c->MakeShape({{c->Dim(input_shape, 0), c->Dim(input_shape, 2),
    578                        c->Dim(input_shape, 3), c->Dim(input_shape, 1)}});
    579     stride_rows = strides[2];
    580     stride_cols = strides[3];
    581   } else {
    582     stride_rows = strides[1];
    583     stride_cols = strides[2];
    584   }
    585 
    586   DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
    587   DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
    588   DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
    589 
    590   DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0);
    591   DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1);
    592   DimensionHandle input_depth = c->Dim(filter_shape, 2);
    593   DimensionHandle depth_multiplier = c->Dim(filter_shape, 3);
    594 
    595   // Check that the input depths are compatible.
    596   TF_RETURN_IF_ERROR(
    597       c->Merge(c->Dim(input_shape, 3), input_depth, &input_depth));
    598 
    599   DimensionHandle output_depth;
    600   TF_RETURN_IF_ERROR(c->Multiply(input_depth, depth_multiplier, &output_depth));
    601 
    602   Padding padding;
    603   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
    604 
    605   // TODO(mrry,shlens): Raise an error if the stride would cause
    606   // information in the input to be ignored. This will require a change
    607   // in the kernel implementation.
    608   DimensionHandle output_rows, output_cols;
    609 
    610   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
    611       c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows));
    612   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
    613       c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols));
    614 
    615   ShapeHandle output_shape;
    616   if (data_format == "NCHW") {
    617     output_shape =
    618         c->MakeShape({batch_size_dim, output_depth, output_rows, output_cols});
    619   } else {
    620     output_shape =
    621         c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth});
    622   }
    623   c->set_output(0, output_shape);
    624   return Status::OK();
    625 }
    626 
    627 Status AvgPoolShape(shape_inference::InferenceContext* c) {
    628   string data_format_str;
    629   TensorFormat data_format;
    630   Status s = c->GetAttr("data_format", &data_format_str);
    631   if (s.ok()) {
    632     FormatFromString(data_format_str, &data_format);
    633   } else {
    634     data_format = FORMAT_NHWC;
    635   }
    636 
    637   const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
    638   ShapeHandle input_shape;
    639   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));
    640 
    641   TF_RETURN_IF_ERROR(
    642       CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));
    643 
    644   std::vector<int32> strides;
    645   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
    646   if (strides.size() != 4) {
    647     return errors::InvalidArgument(
    648         "AvgPool requires the stride attribute to contain 4 values, but got: ",
    649         strides.size());
    650   }
    651 
    652   std::vector<int32> kernel_sizes;
    653   TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
    654   if (kernel_sizes.size() != 4) {
    655     return errors::InvalidArgument(
    656         "AvgPool requires the ksize attribute to contain 4 values, but got: ",
    657         kernel_sizes.size());
    658   }
    659 
    660   int32 stride_rows = GetTensorDim(strides, data_format, 'H');
    661   int32 stride_cols = GetTensorDim(strides, data_format, 'W');
    662   int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
    663   int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');
    664 
    665   constexpr int num_spatial_dims = 2;
    666   DimensionHandle batch_size_dim = c->Dim(
    667       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
    668   DimensionHandle in_rows_dim = c->Dim(
    669       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
    670   DimensionHandle in_cols_dim = c->Dim(
    671       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
    672   DimensionHandle depth_dim = c->Dim(
    673       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
    674 
    675   Padding padding;
    676   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
    677 
    678   // TODO(mrry,shlens): Raise an error if the stride would cause
    679   // information in the input to be ignored. This will require a change
    680   // in the kernel implementation.
    681 
    682   DimensionHandle output_rows, output_cols;
    683   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
    684       c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
    685   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
    686       c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
    687 
    688   ShapeHandle output_shape;
    689   TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
    690                                          {output_rows, output_cols}, depth_dim,
    691                                          &output_shape, c));
    692   c->set_output(0, output_shape);
    693   return Status::OK();
    694 }
    695 
    696 Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
    697   ShapeHandle x;
    698   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));
    699 
    700   bool is_training;
    701   TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
    702   int number_inputs = (is_training) ? 3 : 5;
    703   string data_format;
    704   TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format));
    705   DimensionHandle channel_dim =
    706       (data_format == "NHWC") ? c->Dim(x, 3) : c->Dim(x, 1);
    707 
    708   // covers scale, offset, and if is_training is false, mean, variance
    709   for (int i = 1; i < number_inputs; ++i) {
    710     ShapeHandle vec;
    711     TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
    712     TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
    713   }
    714 
    715   ShapeHandle y;
    716   if (data_format == "NHWC") {
    717     TF_RETURN_IF_ERROR(c->ReplaceDim(x, 3, channel_dim, &y));
    718   } else {
    719     TF_RETURN_IF_ERROR(c->ReplaceDim(x, 1, channel_dim, &y));
    720   }
    721   c->set_output(0, y);
    722   ShapeHandle vector_shape = c->Vector(channel_dim);
    723   c->set_output(1, vector_shape);
    724   c->set_output(2, vector_shape);
    725   c->set_output(3, vector_shape);
    726   c->set_output(4, vector_shape);
    727   return Status::OK();
    728 }
    729 
    730 Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
    731   ShapeHandle y_backprop;
    732   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop));
    733   ShapeHandle x;
    734   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x));
    735 
    736   bool is_training;
    737   string data_format;
    738   TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
    739   TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format));
    740   DimensionHandle channel_dim =
    741       (data_format == "NHWC") ? c->Dim(y_backprop, 3) : c->Dim(y_backprop, 1);
    742   if (data_format == "NHWC") {
    743     TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 3), &channel_dim));
    744   } else {
    745     TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 1), &channel_dim));
    746   }
    747 
    748   // covers scale, mean (reserve_space_1), variance (reserve_space_2)
    749   for (int i = 2; i < 5; ++i) {
    750     ShapeHandle vec;
    751     TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
    752     TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
    753   }
    754 
    755   ShapeHandle x_backprop;
    756   if (data_format == "NHWC") {
    757     TF_RETURN_IF_ERROR(c->ReplaceDim(y_backprop, 3, channel_dim, &x_backprop));
    758   } else {
    759     TF_RETURN_IF_ERROR(c->ReplaceDim(y_backprop, 1, channel_dim, &x_backprop));
    760   }
    761   c->set_output(0, x_backprop);
    762   c->set_output(1, c->Vector(channel_dim));
    763   c->set_output(2, c->Vector(channel_dim));
    764   // Set the correct shapes for reserve_spaces
    765   // so that gradients can be performed when
    766   // the op is in a symbolic condition.
    767   if (is_training) {
    768     c->set_output(3, c->Vector(0));
    769     c->set_output(4, c->Vector(0));
    770   } else {
    771     c->set_output(3, c->Vector(channel_dim));
    772     c->set_output(4, c->Vector(channel_dim));
    773   }
    774   return Status::OK();
    775 }
    776 
    777 Status MaxPoolShape(shape_inference::InferenceContext* c) {
    778   string data_format_str;
    779   TensorFormat data_format;
    780   Status s = c->GetAttr("data_format", &data_format_str);
    781   if (s.ok()) {
    782     FormatFromString(data_format_str, &data_format);
    783   } else {
    784     data_format = FORMAT_NHWC;
    785   }
    786 
    787   const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
    788   ShapeHandle input_shape;
    789   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));
    790 
    791   TF_RETURN_IF_ERROR(
    792       CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));
    793 
    794   std::vector<int32> strides;
    795   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
    796   if (strides.size() != 4) {
    797     return errors::InvalidArgument(
    798         "MaxPool requires the stride attribute to contain 4 values, but got: ",
    799         strides.size());
    800   }
    801 
    802   std::vector<int32> kernel_sizes;
    803   TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
    804   if (kernel_sizes.size() != 4) {
    805     return errors::InvalidArgument(
    806         "MaxPool requires the ksize attribute to contain 4 values, but got: ",
    807         kernel_sizes.size());
    808   }
    809 
    810   int32 stride_depth = GetTensorDim(strides, data_format, 'C');
    811   int32 stride_rows = GetTensorDim(strides, data_format, 'H');
    812   int32 stride_cols = GetTensorDim(strides, data_format, 'W');
    813   int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C');
    814   int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
    815   int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');
    816 
    817   constexpr int num_spatial_dims = 2;
    818   DimensionHandle batch_size_dim = c->Dim(
    819       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
    820   DimensionHandle in_rows_dim = c->Dim(
    821       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
    822   DimensionHandle in_cols_dim = c->Dim(
    823       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
    824   DimensionHandle in_depth_dim = c->Dim(
    825       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
    826 
    827   Padding padding;
    828   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
    829 
    830   ShapeHandle output_shape;
    831   DimensionHandle output_rows, output_cols, output_depth;
    832   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
    833       c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
    834   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
    835       c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
    836   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
    837       c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth));
    838 
    839   TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
    840                                          {output_rows, output_cols},
    841                                          output_depth, &output_shape, c));
    842 
    843   c->set_output(0, output_shape);
    844   return Status::OK();
    845 }
    846 
    847 Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) {
    848   string data_format_str;
    849   TensorFormat data_format;
    850   Status s = c->GetAttr("data_format", &data_format_str);
    851   if (s.ok()) {
    852     FormatFromString(data_format_str, &data_format);
    853   } else {
    854     data_format = FORMAT_NHWC;
    855   }
    856 
    857   const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
    858   ShapeHandle input_shape;
    859   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));
    860 
    861   TF_RETURN_IF_ERROR(
    862       CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));
    863 
    864   std::vector<int32> kernel_sizes;
    865   std::vector<int32> strides;
    866 
    867   if (c->num_inputs() + 2 == num_inputs) {
    868     TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
    869 
    870     TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
    871   } else {
    872     // Verify shape of ksize and strides input.
    873     ShapeHandle size;
    874     DimensionHandle unused;
    875     TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 2), 1, &size));
    876     TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));
    877     TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &size));
    878     TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));
    879 
    880     const Tensor* kernel_sizes_tensor = c->input_tensor(c->num_inputs() - 2);
    881     if (kernel_sizes_tensor == nullptr) {
    882       c->set_output(0, c->UnknownShape());
    883       return Status::OK();
    884     }
    885     kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements());
    886     auto kernel_sizes_vec = kernel_sizes_tensor->flat<int32>();
    887     std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(),
    888                 kernel_sizes.begin());
    889 
    890     const Tensor* strides_tensor = c->input_tensor(c->num_inputs() - 1);
    891     if (strides_tensor == nullptr) {
    892       c->set_output(0, c->UnknownShape());
    893       return Status::OK();
    894     }
    895     strides.resize(strides_tensor->shape().num_elements());
    896     auto strides_vec = strides_tensor->flat<int32>();
    897     std::copy_n(&strides_vec(0), strides.size(), strides.begin());
    898   }
    899 
    900   if (strides.size() != 4) {
    901     return errors::InvalidArgument(
    902         "MaxPool requires the stride attribute to contain 4 values, but "
    903         "got: ",
    904         strides.size());
    905   }
    906   if (kernel_sizes.size() != 4) {
    907     return errors::InvalidArgument(
    908         "MaxPool requires the ksize attribute to contain 4 values, but got: ",
    909         kernel_sizes.size());
    910   }
    911 
    912   int32 stride_depth = GetTensorDim(strides, data_format, 'C');
    913   int32 stride_rows = GetTensorDim(strides, data_format, 'H');
    914   int32 stride_cols = GetTensorDim(strides, data_format, 'W');
    915   int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C');
    916   int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
    917   int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');
    918 
    919   constexpr int num_spatial_dims = 2;
    920   DimensionHandle batch_size_dim = c->Dim(
    921       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
    922   DimensionHandle in_rows_dim = c->Dim(
    923       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
    924   DimensionHandle in_cols_dim = c->Dim(
    925       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
    926   DimensionHandle in_depth_dim = c->Dim(
    927       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
    928 
    929   Padding padding;
    930   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
    931 
    932   ShapeHandle output_shape;
    933   DimensionHandle output_rows, output_cols, output_depth;
    934   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
    935       c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
    936   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
    937       c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
    938   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
    939       c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth));
    940 
    941   TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
    942                                          {output_rows, output_cols},
    943                                          output_depth, &output_shape, c));
    944 
    945   c->set_output(0, output_shape);
    946   return Status::OK();
    947 }
    948 
// Shape function for 3-D pooling ops (e.g. AvgPool3D / MaxPool3D).
// The input must be rank 5, in either NDHWC or NCDHW layout.
Status Pool3DShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));

  // data_format is optional for these ops; when the attr is missing the
  // NDHWC branch below is taken.
  string data_format;
  Status s = c->GetAttr("data_format", &data_format);

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  if (strides.size() != 5) {
    return errors::InvalidArgument(
        "Pool3D ops require the stride attribute to contain 5 values, but "
        "got: ",
        strides.size());
  }

  std::vector<int32> kernel_sizes;
  TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
  if (kernel_sizes.size() != 5) {
    return errors::InvalidArgument(
        "Pool3D requires the ksize attribute to contain 5 values, but got: ",
        kernel_sizes.size());
  }

  int32 stride_planes, stride_rows, stride_cols;
  int32 kernel_planes, kernel_rows, kernel_cols;

  if (s.ok() && data_format == "NCDHW") {
    // Convert input_shape to NDHWC.
    // The three spatial axes '0'/'1'/'2' correspond to planes/rows/cols;
    // after this rewrite the rest of the function indexes the shape as
    // NDHWC regardless of the caller's layout.
    auto dim = [&](char dimension) {
      return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension));
    };
    input_shape =
        c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
    // NCDHW attr vectors are laid out as [N, C, D, H, W].
    stride_planes = strides[2];
    stride_rows = strides[3];
    stride_cols = strides[4];
    kernel_planes = kernel_sizes[2];
    kernel_rows = kernel_sizes[3];
    kernel_cols = kernel_sizes[4];
  } else {
    // NDHWC attr vectors are laid out as [N, D, H, W, C].
    stride_planes = strides[1];
    stride_rows = strides[2];
    stride_cols = strides[3];
    kernel_planes = kernel_sizes[1];
    kernel_rows = kernel_sizes[2];
    kernel_cols = kernel_sizes[3];
  }

  // From here on input_shape is in NDHWC order.
  DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
  DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
  DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
  DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
  DimensionHandle output_depth_dim = c->Dim(input_shape, 4);

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  // TODO(mrry,shlens): Raise an error if the stride would cause
  // information in the input to be ignored. This will require a change
  // in the kernel implementation.
  DimensionHandle output_planes, output_rows, output_cols;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_planes_dim, kernel_planes, stride_planes, padding, &output_planes));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));

  // Emit the output in the caller's original layout.
  ShapeHandle output_shape;
  if (data_format == "NCDHW") {
    output_shape = c->MakeShape({batch_size_dim, output_depth_dim,
                                 output_planes, output_rows, output_cols});
  } else {
    output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows,
                                 output_cols, output_depth_dim});
  }

  c->set_output(0, output_shape);
  return Status::OK();
}
   1030 
   1031 Status UnknownShape(shape_inference::InferenceContext* c) {
   1032   for (int i = 0; i < c->num_outputs(); ++i) {
   1033     c->set_output(i, c->UnknownShape());
   1034   }
   1035   return Status::OK();
   1036 }
   1037 
   1038 template <typename T>
   1039 Status ReductionShapeHelper(const Tensor* reduction_indices_t,
   1040                             const int32 input_rank,
   1041                             std::set<int64>& true_indices) {
   1042   auto reduction_indices = reduction_indices_t->flat<T>();
   1043   for (int i = 0; i < reduction_indices_t->NumElements(); ++i) {
   1044     const T reduction_index = reduction_indices(i);
   1045     if (reduction_index < -input_rank || reduction_index >= input_rank) {
   1046       return errors::InvalidArgument("Invalid reduction dimension ",
   1047                                      reduction_index, " for input with ",
   1048                                      input_rank, " dimensions.");
   1049     }
   1050 
   1051     auto wrapped_index = reduction_index;
   1052     if (wrapped_index < 0) {
   1053       wrapped_index += input_rank;
   1054     }
   1055 
   1056     true_indices.insert(wrapped_index);
   1057   }
   1058   return Status::OK();
   1059 }
   1060 
   1061 Status ReductionShape(InferenceContext* c) {
   1062   ShapeHandle input = c->input(0);
   1063 
   1064   ShapeHandle indices;
   1065   // Older versions of TensorFlow accidentally allowed higher rank tensors like
   1066   // [[1,2]] or [[1],[2]] to represent axis=[1,2].
   1067   if (c->graph_def_version() < 21) {
   1068     indices = c->input(1);
   1069   } else {
   1070     TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &indices));
   1071   }
   1072 
   1073   bool keep_dims;
   1074   TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims));
   1075 
   1076   const Tensor* reduction_indices_t = c->input_tensor(1);
   1077   if (reduction_indices_t == nullptr || !c->RankKnown(input)) {
   1078     // If we do not have the reduction values at runtime, or the
   1079     // rank of the input, we don't know the output shape.
   1080 
   1081     if (keep_dims && c->RankKnown(input)) {
   1082       // output rank matches input input if <keep_dims>.
   1083       c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
   1084       return Status::OK();
   1085     } else {
   1086       return shape_inference::UnknownShape(c);
   1087     }
   1088   }
   1089 
   1090   const int32 input_rank = c->Rank(input);
   1091   std::set<int64> true_indices;
   1092   if (reduction_indices_t->dtype() == DataType::DT_INT32) {
   1093     TF_RETURN_IF_ERROR(ReductionShapeHelper<int32>(reduction_indices_t,
   1094                                                    input_rank, true_indices));
   1095   } else if (reduction_indices_t->dtype() == DataType::DT_INT64) {
   1096     TF_RETURN_IF_ERROR(ReductionShapeHelper<int64>(reduction_indices_t,
   1097                                                    input_rank, true_indices));
   1098   } else {
   1099     return errors::InvalidArgument(
   1100         "reduction_indices can only be int32 or int64");
   1101   }
   1102 
   1103   std::vector<DimensionHandle> dims;
   1104   for (int i = 0; i < input_rank; ++i) {
   1105     if (true_indices.count(i) > 0) {
   1106       if (keep_dims) {
   1107         dims.emplace_back(c->MakeDim(1));
   1108       }
   1109     } else {
   1110       dims.emplace_back(c->Dim(input, i));
   1111     }
   1112   }
   1113 
   1114   c->set_output(0, c->MakeShape(dims));
   1115   return Status::OK();
   1116 }
   1117 
// Shared shape function for Concat/ConcatV2: the value tensors occupy inputs
// [start_value_index, end_value_index) and the scalar concat axis is at
// input dim_index.  Non-concat dimensions are merged across all inputs; the
// concat dimension is the sum of the inputs' concat dimensions.
Status ConcatShapeHelper(InferenceContext* c, int start_value_index,
                         int end_value_index, int dim_index) {
  ShapeHandle unused;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(dim_index), 0, &unused));
  const Tensor* concat_dim_t = c->input_tensor(dim_index);
  if (concat_dim_t == nullptr) {
    // Return an unknown shape with same rank as inputs, or an unknown rank
    // if no input's rank is known.

    // Find rank.
    int32 rank = InferenceContext::kUnknownRank;
    for (int i = start_value_index; i < end_value_index; ++i) {
      if (rank == InferenceContext::kUnknownRank) rank = c->Rank(c->input(i));
      if (rank != InferenceContext::kUnknownRank) {
        break;
      }
    }
    if (rank == InferenceContext::kUnknownRank) {
      c->set_output(0, c->UnknownShape());
      return Status::OK();
    } else if (rank == 0) {
      return errors::InvalidArgument(
          "Can't concatenate scalars (use tf.stack instead)");
    } else {
      for (int i = start_value_index; i < end_value_index; ++i) {
        // Check that all the inputs are of the correct rank.
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), rank, &unused));
      }
    }
    // Build result of <rank> different unknown dims.
    std::vector<DimensionHandle> dims;
    dims.reserve(rank);
    for (int i = 0; i < rank; ++i) dims.push_back(c->UnknownDim());
    c->set_output(0, c->MakeShape(dims));
    return Status::OK();
  }

  // Merge all the non-concat dims, and sum the concat dim to make an output
  // shape.
  const int32 concat_dim = concat_dim_t->scalar<int32>()();

  // Minimum required number of dimensions.
  const int min_rank = concat_dim < 0 ? -concat_dim : concat_dim + 1;

  // The output is assembled as: dims before the concat axis (merged across
  // inputs), the summed concat dim, and the dims after the axis (merged).
  ShapeHandle output_before;
  ShapeHandle output_after;

  // Seed the before/middle/after pieces from the last value input.
  ShapeHandle input = c->input(end_value_index - 1);
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
  TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &output_before));
  DimensionHandle output_middle = c->Dim(input, concat_dim);
  // concat_dim == -1 denotes the last dimension, so nothing follows it.
  if (concat_dim == -1) {
    output_after = c->Scalar();  // no dimensions.
  } else {
    TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &output_after));
  }

  // Fold in the remaining value inputs, merging the before/after pieces and
  // accumulating the concat-axis sizes.
  for (int i = end_value_index - 2; i >= start_value_index; --i) {
    ShapeHandle before;
    ShapeHandle after;
    input = c->input(i);
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
    TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &before));
    DimensionHandle middle = c->Dim(input, concat_dim);
    if (concat_dim == -1) {
      after = c->Scalar();
    } else {
      TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &after));
    }

    TF_RETURN_IF_ERROR(c->Merge(before, output_before, &output_before));
    TF_RETURN_IF_ERROR(c->Add(output_middle, middle, &output_middle));
    TF_RETURN_IF_ERROR(c->Merge(after, output_after, &output_after));
  }

  // Stitch the three pieces back together into the output shape.
  ShapeHandle s;
  TF_RETURN_IF_ERROR(
      c->Concatenate(output_before, c->Vector(output_middle), &s));
  TF_RETURN_IF_ERROR(c->Concatenate(s, output_after, &s));
  c->set_output(0, s);
  return Status::OK();
}
   1200 
   1201 Status ConcatShape(InferenceContext* c, int num_inputs_to_concat) {
   1202   return ConcatShapeHelper(c, 1 /* start_value_index */,
   1203                            1 + num_inputs_to_concat /* end_value_index */,
   1204                            0 /* dim_index */);
   1205 }
   1206 
   1207 Status ConcatV2Shape(InferenceContext* c) {
   1208   return ConcatShapeHelper(c, 0 /* start_value_index */,
   1209                            c->num_inputs() - 1 /* end_value_index */,
   1210                            c->num_inputs() - 1 /* dim_index */);
   1211 }
   1212 
// Shape function for elementwise binary ops with numpy-style broadcasting:
// the two input shapes are right-aligned, and each aligned dimension pair
// must be equal or contain a 1 (which broadcasts to the other size).
Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
  ShapeHandle shape_x = c->input(0);
  ShapeHandle shape_y = c->input(1);
  // Without both ranks the alignment below cannot be performed.
  if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) {
    c->set_output(0, c->UnknownShape());
    return Status::OK();
  }
  const int32 rank_x = c->Rank(shape_x);
  const int32 rank_y = c->Rank(shape_y);
  const int32 rank_out = std::max(rank_x, rank_y);

  // To compute the broadcast dimensions, we zip together shape_x and shape_y
  // and
  // pad with 1 to make them the same length.
  std::vector<DimensionHandle> dims;
  // dim_one stands in for the implicit leading 1-dims of the shorter shape;
  // it is only needed (and only initialized) when the ranks differ.
  DimensionHandle dim_one;
  if (rank_x != rank_y) dim_one = c->MakeDim(1);
  for (int i = 0; i < rank_out; ++i) {
    const auto dim_x = i < (rank_out - rank_x)
                           ? dim_one
                           : c->Dim(shape_x, i - (rank_out - rank_x));
    const bool dim_y_is_one = (i < (rank_out - rank_y));
    const auto dim_y =
        dim_y_is_one ? dim_one : c->Dim(shape_y, i - (rank_out - rank_y));
    if (!c->ValueKnown(dim_x) || !c->ValueKnown(dim_y)) {
      // One or both dimensions is unknown.
      //
      // - If either dimension is greater than 1, we assume that the program is
      // correct, and the other dimension will be broadcast to match it.
      // TODO(cwhipkey): For shape inference, if we eliminate the shape checks
      // in C++ op code, we must still assert that the unknown dim is either 1
      // or the same as the known dim.
      // - If either dimension is 1, the other dimension is the output.
      if (c->Value(dim_x) > 1) {
        dims.push_back(dim_x);
      } else if (c->Value(dim_y) > 1) {
        dims.push_back(dim_y);
      } else if (c->Value(dim_x) == 1) {
        dims.push_back(dim_y);
      } else if (c->Value(dim_y) == 1) {
        dims.push_back(dim_x);
      } else if (dim_y.SameHandle(dim_x)) {
        // The same unknown dimension appears on both sides, so the output
        // shares it.
        dims.push_back(dim_x);
      } else {
        dims.push_back(c->UnknownDim());
      }
    } else if (c->Value(dim_x) == 1 || c->Value(dim_y) == 1) {
      // Both dimensions are known and at least one is 1: the other wins.
      if (c->Value(dim_x) == 1 && !dim_y_is_one) {
        // We will broadcast dim_x to dim_y.
        dims.push_back(dim_y);
      } else {
        DCHECK_EQ(c->Value(dim_y), 1);
        // We will broadcast dim_y to dim_x.
        dims.push_back(dim_x);
      }
    } else {
      // Both dimensions are known and neither is 1: they must be equal, and
      // Merge reports the error if they are not.
      DimensionHandle dim;
      TF_RETURN_IF_ERROR(c->Merge(dim_x, dim_y, &dim));
      dims.push_back(dim);
    }
  }

  c->set_output(0, c->MakeShape(dims));
  return Status::OK();
}
   1278 
   1279 Status RandomShape(shape_inference::InferenceContext* c) {
   1280   shape_inference::ShapeHandle out;
   1281   TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
   1282   c->set_output(0, out);
   1283   return Status::OK();
   1284 }
   1285 
   1286 Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
   1287                             ShapeHandle values_shape, ShapeHandle shape_shape) {
   1288   // Validate ranks.
   1289   ShapeHandle unused_shape;
   1290   TF_RETURN_IF_ERROR(c->WithRank(indices_shape, 2, &unused_shape));
   1291   TF_RETURN_IF_ERROR(c->WithRank(values_shape, 1, &unused_shape));
   1292   TF_RETURN_IF_ERROR(c->WithRank(shape_shape, 1, &unused_shape));
   1293 
   1294   // Number of elements in indices and values must match.
   1295   DimensionHandle num_index_elements_dim = c->Dim(indices_shape, 0);
   1296   if (c->ValueKnown(num_index_elements_dim)) {
   1297     DimensionHandle num_values_elements_dim = c->Dim(values_shape, 0);
   1298     if (c->ValueKnown(num_values_elements_dim)) {
   1299       int64 num_index_elements = c->Value(num_index_elements_dim);
   1300       int64 num_values_elements = c->Value(num_values_elements_dim);
   1301       if (num_index_elements != num_values_elements) {
   1302         return errors::InvalidArgument("Number of elements in index (",
   1303                                        num_index_elements, ") and values (",
   1304                                        num_values_elements, ") do not match.");
   1305       }
   1306     }
   1307   }
   1308 
   1309   // Rank embedded in indices must match shape.
   1310   DimensionHandle index_rank_dim = c->Dim(indices_shape, 1);
   1311   if (c->ValueKnown(index_rank_dim)) {
   1312     DimensionHandle shape_rank_dim = c->Dim(shape_shape, 0);
   1313     if (c->ValueKnown(shape_rank_dim)) {
   1314       int64 index_rank = c->Value(index_rank_dim);
   1315       int32 shape_rank = c->Value(shape_rank_dim);
   1316       if (index_rank != shape_rank) {
   1317         return errors::InvalidArgument("Index rank (", index_rank,
   1318                                        ") and shape rank (", shape_rank,
   1319                                        ") do not match.");
   1320       }
   1321     }
   1322   }
   1323 
   1324   return Status::OK();
   1325 }
   1326 
// Shape function for scatter_nd update ops.  Validates that the outer dims of
// indices match the outer dims of updates, and that the remaining dims of
// updates match the sliced suffix of the input, then passes the input shape
// through as the output.
Status ScatterNdUpdateShape(InferenceContext* c) {
  ShapeHandle input_shape = c->input(0);
  // For resource-variable variants, input 0 is a handle; use the shape
  // recorded on the handle instead of the handle tensor's own shape.
  if (c->input_handle_shapes_and_types(0) != nullptr) {
    input_shape = (*c->input_handle_shapes_and_types(0))[0].shape;
  }
  ShapeHandle indices_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape));
  ShapeHandle updates_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape));

  // Scattering anything into an empty output is an error.
  if (c->Value(c->NumElements(input_shape)) == 0 &&
      (c->Value(c->NumElements(indices_shape)) > 0 ||
       c->Value(c->NumElements(updates_shape)) > 0)) {
    return errors::InvalidArgument(
        "Indices and updates specified for empty output shape");
  }

  if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) {
    // All but the last dim of indices address individual scatter targets.
    const int64 num_outer_dims = c->Rank(indices_shape) - 1;
    // The last dim of indices gives how many input dims each index selects.
    const DimensionHandle index_size = c->Dim(indices_shape, -1);

    // We can only do more validation if the last dimension of indices
    // is a known value.
    if (c->ValueKnown(index_size)) {
      const int64 ix = c->Value(index_size);
      ShapeHandle unused;
      // The outer dims of indices and updates must agree: one update slice
      // per index row.
      ShapeHandle prefix_indices;
      TF_RETURN_IF_ERROR(
          c->Subshape(indices_shape, 0, num_outer_dims, &prefix_indices));
      ShapeHandle prefix_updates;
      TF_RETURN_IF_ERROR(
          c->Subshape(updates_shape, 0, num_outer_dims, &prefix_updates));

      Status s = c->Merge(prefix_indices, prefix_updates, &unused);
      if (!s.ok()) {
        return errors::InvalidArgument(
            "The outer ", num_outer_dims,
            " dimensions of indices.shape=", c->DebugString(indices_shape),
            " must match the outer ", num_outer_dims,
            " dimensions of updates.shape=", c->DebugString(updates_shape),
            ": ", s.error_message());
      }

      // Each update slice must match the input's shape beyond the first `ix`
      // dims (the dims consumed by the index).
      ShapeHandle input_suffix;
      TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &input_suffix));
      ShapeHandle suffix_updates;
      TF_RETURN_IF_ERROR(
          c->Subshape(updates_shape, num_outer_dims, &suffix_updates));
      s = c->Merge(input_suffix, suffix_updates, &unused);
      if (!s.ok()) {
        return errors::InvalidArgument(
            "The inner ", c->Rank(input_shape) - ix,
            " dimensions of input.shape=", c->DebugString(input_shape),
            " must match the inner ", c->Rank(updates_shape) - num_outer_dims,
            " dimensions of updates.shape=", c->DebugString(updates_shape),
            ": ", s.error_message());
      }
    }
  }

  // Resource variants have no tensor output; otherwise the scatter preserves
  // the input shape.
  if (c->input_handle_shapes_and_types(0) == nullptr) {
    c->set_output(0, input_shape);
  }
  return Status::OK();
}
   1392 
   1393 Status ExplicitShape(InferenceContext* c) {
   1394   PartialTensorShape shape;
   1395   TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
   1396   ShapeHandle output_shape;
   1397   TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &output_shape));
   1398   c->set_output(0, output_shape);
   1399   return Status::OK();
   1400 }
   1401 
   1402 }  // namespace shape_inference
   1403 
   1404 }  // namespace tensorflow
   1405