Home | History | Annotate | Download | only in util
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_UTIL_TENSOR_FORMAT_H_
     17 #define TENSORFLOW_UTIL_TENSOR_FORMAT_H_
     18 
     19 #include <array>
     20 #include <vector>
     21 
     22 #include "tensorflow/core/framework/tensor.h"
     23 #include "tensorflow/core/lib/gtl/inlined_vector.h"
     24 #include "tensorflow/core/platform/types.h"
     25 
     26 namespace tensorflow {
     27 
     28 // Tensor format for input/output activations used in convolution operations.
     29 // The mnemonics specify the meaning of each tensor dimension sorted from
     30 // largest to smallest memory stride.
     31 // N = Batch, H = Image Height, W = Image Width, C = Number of Channels.
     32 enum TensorFormat {
     33   // FORMAT_NHWC is the default format in TensorFlow.
     34   FORMAT_NHWC = 0,
     35 
     36   // FORMAT_NCHW often improves performance on GPUs.
     37   FORMAT_NCHW = 1,
     38 
     39   // NCHW_VECT_C is the most performant tensor format for cudnn6's quantized
     40   // int8 convolution and fused convolution. It is laid out in the same order
     41   // as NCHW, except that the size of the Channels dimension is divided by 4,
     42   // and a new dimension of size 4 is appended, which packs 4 adjacent channel
     43   // activations for the same pixel into an int32. Thus an NCHW format tensor
     44   // with dimensions [N, C, H, W] would have dimensions [N, C/4, H, W, 4] in
     45   // NCHW_VECT_C format.
     46   // A pre-condition of this format is that C must be a multiple of 4.
     47   FORMAT_NCHW_VECT_C = 2,
     48 };
     49 
     50 // Tensor format for convolutional filters.
     51 // The mnemonics specify the meaning of each tensor dimension sorted
     52 // from largest to smallest memory stride.
     53 // H = Kernel Height, W = Kernel Width, I = Input Channels, O = Output Channels.
     54 // Note: In cudnnGetFilter4dDescriptor(), 'O' is called 'K', 'I' is called 'C'.
     55 enum FilterTensorFormat {
     56   // FORMAT_HWIO is the default filter format in TensorFlow.
     57   // Ops that do not have a 'filter_format' attribute will assume this format.
     58   FORMAT_HWIO = 0,
     59 
     60   // FORMAT_OIHW often improves performance on GPUs.
     61   FORMAT_OIHW = 1,
     62 
     63   // OIHW_VECT_I is the most performant tensor format for cudnn6's quantized
     64   // int8 convolution and fused convolution. It is analagous to the NCHW_VECT_C
     65   // data format. It is laid out in the same order as OIHW, except that the size
     66   // of the Input Channels dimension is divided by 4, and a new dimension of
     67   // size 4 is appended, which packs 4 adjacent input channel weights into an
     68   // int32. Thus an OIHW format filter with dimensions [O, I, H, W] would have
     69   // dimensions [O, I/4, H, W, 4] in OIHW_VECT_I format.
     70   // A pre-condition of this format is that I must be a multiple of 4.
     71   FORMAT_OIHW_VECT_I = 2,
     72 };
     73 
// Parse an activation tensor format from the given string.
// Returns true and stores the result in '*format' if parsing succeeds;
// returns false if it fails.
bool FormatFromString(const string& format_str, TensorFormat* format);

// Parse a filter tensor format from the given string.
// Returns true and stores the result in '*format' if parsing succeeds;
// returns false if it fails.
bool FilterFormatFromString(const string& format_str,
                            FilterTensorFormat* format);

// Convert an activation tensor format into its string name.
string ToString(TensorFormat format);

// Convert a filter tensor format into its string name.
string ToString(FilterTensorFormat format);
     88 
     89 // Returns the number of spatial dims of a tensor of rank 'num_dims' and tensor
     90 // format 'format'.
     91 inline int GetTensorSpatialDims(int num_dims, TensorFormat format) {
     92   if (format == FORMAT_NCHW_VECT_C) {
     93     return num_dims - 3;  // Exclude N,C,InnerC.
     94   } else {
     95     return num_dims - 2;  // Exclude N,C.
     96   }
     97 }
     98 
     99 inline int GetFilterTensorSpatialDims(int num_dims, FilterTensorFormat format) {
    100   if (format == FORMAT_OIHW_VECT_I) {
    101     return num_dims - 3;  // Exclude O,I,InnerI.
    102   } else {
    103     return num_dims - 2;  // Exclude O,I.
    104   }
    105 }
    106 
    107 // Returns the rank of a tensor with 'num_spatial_dims' spatial dimensions and
    108 // tensor format 'format'. This is the inverse of GetTensorSpatialDims.
    109 inline int GetTensorDimsFromSpatialDims(int num_spatial_dims,
    110                                         TensorFormat format) {
    111   if (format == FORMAT_NCHW_VECT_C) {
    112     return num_spatial_dims + 3;  // Include N,C,InnerC.
    113   } else {
    114     return num_spatial_dims + 2;  // Include N,C.
    115   }
    116 }
    117 
    118 // Returns the rank of a tensor with 'num_spatial_dims' spatial dimensions and
    119 // filter tensor format 'format'.
    120 inline int GetFilterTensorDimsFromSpatialDims(int num_spatial_dims,
    121                                               FilterTensorFormat format) {
    122   if (format == FORMAT_OIHW_VECT_I) {
    123     return num_spatial_dims + 3;  // Include O,I,InnerI.
    124   } else {
    125     return num_spatial_dims + 2;  // Include O,I.
    126   }
    127 }
    128 
    129 // Returns the index of the batch dimension.
    130 inline int GetTensorBatchDimIndex(int num_dims, TensorFormat format) {
    131   switch (format) {
    132     case FORMAT_NHWC:
    133     case FORMAT_NCHW:
    134     case FORMAT_NCHW_VECT_C:
    135       return 0;
    136     default:
    137       LOG(FATAL) << "Unknown format " << format;
    138       return -1;  // Avoid compiler warning about missing return value
    139   }
    140 }
    141 
    142 // Returns the index of the feature dimension. If format is NCHW_VECT_C, returns
    143 // the index of the outer feature dimension (i.e. dimension 1, whose size would
    144 // be num_features / 4 in this case).
    145 inline int GetTensorFeatureDimIndex(int num_dims, TensorFormat format) {
    146   switch (format) {
    147     case FORMAT_NHWC:
    148       return num_dims - 1;
    149     case FORMAT_NCHW:
    150     case FORMAT_NCHW_VECT_C:
    151       return 1;
    152     default:
    153       LOG(FATAL) << "Unknown format " << format;
    154       return -1;  // Avoid compiler warning about missing return value
    155   }
    156 }
    157 
    158 // Returns the index of the inner feature dimension.
    159 inline int GetTensorInnerFeatureDimIndex(int num_dims, TensorFormat format) {
    160   DCHECK_EQ(format, FORMAT_NCHW_VECT_C);
    161   return num_dims - 1;
    162 }
    163 
    164 // Returns the index of the `dim`-th spatial dimension.
    165 inline int GetTensorSpatialDimIndex(int num_dims, TensorFormat format,
    166                                     int dim) {
    167   CHECK(dim >= 0 && dim < GetTensorSpatialDims(num_dims, format))
    168       << dim << " " << num_dims << " " << ToString(format);
    169   switch (format) {
    170     case FORMAT_NHWC:
    171       return dim + 1;
    172     case FORMAT_NCHW:
    173     case FORMAT_NCHW_VECT_C:
    174       return dim + 2;
    175     default:
    176       LOG(FATAL) << "Unknown format " << format;
    177       return -1;  // Avoid compiler warning about missing return value
    178   }
    179 }
    180 
    181 // Returns the index of the `dim`-th spatial dimension.
    182 inline int GetFilterTensorSpatialDimIndex(int num_dims,
    183                                           FilterTensorFormat format, int dim) {
    184   CHECK(dim >= 0 && dim < GetFilterTensorSpatialDims(num_dims, format))
    185       << dim << " " << num_dims << " " << ToString(format);
    186   switch (format) {
    187     case FORMAT_HWIO:
    188       return dim;
    189     case FORMAT_OIHW:
    190     case FORMAT_OIHW_VECT_I:
    191       return dim + 2;
    192     default:
    193       LOG(FATAL) << "Unknown format " << format;
    194       return -1;  // Avoid compiler warning about missing return value
    195   }
    196 }
    197 
    198 // Returns the index of the inner input channels dimension.
    199 inline int GetFilterTensorInnerInputChannelsDimIndex(
    200     int num_dims, FilterTensorFormat format) {
    201   DCHECK_EQ(format, FORMAT_OIHW_VECT_I);
    202   return num_dims - 1;
    203 }
    204 
    205 // Returns the index of the input channels dimension.
    206 // If 'format' is FORMAT_OIHW_VECT_I, returns the dimension index of the
    207 // outer input channel (i.e. 1), which holds num_input_channels / 4.
    208 inline int GetFilterTensorInputChannelsDimIndex(int num_dims,
    209                                                 FilterTensorFormat format) {
    210   switch (format) {
    211     case FORMAT_HWIO:
    212       return num_dims - 2;
    213     case FORMAT_OIHW:
    214     case FORMAT_OIHW_VECT_I:
    215       return 1;
    216     default:
    217       LOG(FATAL) << "Unknown format " << format;
    218       return -1;  // Avoid compiler warning about missing return value
    219   }
    220 }
    221 
    222 // Returns the index of the output channels dimension.
    223 inline int GetFilterTensorOutputChannelsDimIndex(int num_dims,
    224                                                  FilterTensorFormat format) {
    225   switch (format) {
    226     case FORMAT_HWIO:
    227       return num_dims - 1;
    228     case FORMAT_OIHW:
    229     case FORMAT_OIHW_VECT_I:
    230       return 0;
    231     default:
    232       LOG(FATAL) << "Unknown format " << format;
    233       return -1;  // Avoid compiler warning about missing return value
    234   }
    235 }
    236 
    237 // TODO(pauldonnelly): Replace these tensor dimension index functions with
    238 // constant structs to improve performance and reduce code size in Compute()
    239 // functions.
    240 
    241 // Return the dimension index for the specified 'dimension' of the specified
    242 // data 'tensor_format'.  'dimension' is a char that can be 'N' (batch size),
    243 // 'C' (channels), 'H' (height), 'W' (width),  or a numbered spatial dimension:
    244 // '0',  .. (NUM_SPATIAL_DIMS-1)..
    245 // If 'format' is NCHW_VECT_C and 'dimension' is 'C', returns the index of
    246 // the outer channel dimension (i.e. 1).
    247 template <int NUM_SPATIAL_DIMS>
    248 inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
    249   if (format == FORMAT_NHWC) {
    250     // clang-format off
    251     switch (dimension) {
    252       case 'N': return 0;
    253       case '0': return 1;
    254       case '1': return 2;
    255       case '2': return 3;
    256       case 'H': return NUM_SPATIAL_DIMS - 1;
    257       case 'W': return NUM_SPATIAL_DIMS;
    258       case 'C': return NUM_SPATIAL_DIMS + 1;
    259       default:
    260         LOG(FATAL) << "Invalid dimension: " << dimension;
    261         return -1;  // Avoid compiler warning about missing return value
    262     }
    263   } else if (format == FORMAT_NCHW || format == FORMAT_NCHW_VECT_C) {
    264     switch (dimension) {
    265       case 'N': return 0;
    266       case 'C': return 1;
    267       case '0': return 2;
    268       case '1': return 3;
    269       case '2': return 4;
    270       case 'H': return NUM_SPATIAL_DIMS;
    271       case 'W': return NUM_SPATIAL_DIMS + 1;
    272       default:
    273         LOG(FATAL) << "Invalid dimension: " << dimension;
    274         return -1;  // Avoid compiler warning about missing return value
    275     }
    276   } else {
    277     LOG(FATAL) << "Invalid format: " << static_cast<int>(format);
    278     return -1;  // Avoid compiler warning about missing return value
    279   }
    280   // clang-format on
    281 }
    282 
    283 // Return the dimension index for the specified 'dimension' of the specified
    284 // 'filter_tensor_format'.  'dimension' is a char that can be 'O' (num output
    285 // channels), 'I' (num input channels), 'H' (height), 'W' (width), or a
    286 // numbered spatial dimension: '0',  .. (NUM_SPATIAL_DIMS-1).
    287 // If 'format' is OIHW_VECT_I and 'dimension' is 'I', returns the index of the
    288 // outer input channels dimension (i.e. 1).
    289 template <int NUM_SPATIAL_DIMS>
    290 inline int GetFilterDimIndex(FilterTensorFormat filter_tensor_format,
    291                              char dimension) {
    292   // clang-format off
    293   if (filter_tensor_format == FORMAT_HWIO) {
    294     switch (dimension) {
    295       case '0': return 0;
    296       case '1': return 1;
    297       case '2': return 2;
    298       case 'H': return NUM_SPATIAL_DIMS - 2;
    299       case 'W': return NUM_SPATIAL_DIMS - 1;
    300       case 'I': return NUM_SPATIAL_DIMS;
    301       case 'O': return NUM_SPATIAL_DIMS + 1;
    302       default:
    303         LOG(FATAL) << "Invalid dimension: " << dimension;
    304         return -1;  // Avoid compiler warning about missing return value
    305     }
    306   } else if (filter_tensor_format == FORMAT_OIHW ||
    307              filter_tensor_format == FORMAT_OIHW_VECT_I) {
    308     switch (dimension) {
    309       case 'O': return 0;
    310       case 'I': return 1;
    311       case '0': return 2;
    312       case '1': return 3;
    313       case '2': return 4;
    314       case 'H': return NUM_SPATIAL_DIMS;
    315       case 'W': return NUM_SPATIAL_DIMS + 1;
    316       default:
    317         LOG(FATAL) << "Invalid dimension: " << dimension;
    318         return -1;  // Avoid compiler warning about missing return value
    319     }
    320   } else {
    321     LOG(FATAL) << "Invalid format: " << static_cast<int>(filter_tensor_format);
    322     return -1;  // Avoid compiler warning about missing return value
    323   }
    324   // clang-format on
    325 }
    326 
    327 inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
    328   return GetTensorDimIndex<2>(format, dimension);
    329 }
    330 
    331 // Return the element from 'dimension_attributes' that corresponds to the
    332 // specified 'dimension' according to 'tensor_format'.
    333 template <typename T>
    334 T GetTensorDim(gtl::ArraySlice<T> dimension_attributes,
    335                TensorFormat tensor_format, char dimension) {
    336   int index =
    337       (GetTensorSpatialDims(dimension_attributes.size(), tensor_format) == 3)
    338           ? GetTensorDimIndex<3>(tensor_format, dimension)
    339           : GetTensorDimIndex<2>(tensor_format, dimension);
    340   CHECK(index >= 0 && index < dimension_attributes.size())
    341       << "Invalid index from the dimension: " << index << ", " << tensor_format
    342       << ", " << dimension;
    343   return dimension_attributes[index];
    344 }
    345 
    346 // Return the element from 'dimension_attribute' that corresponds to the
    347 // specified 'dimension' according to 'filter_tensor_format'.
    348 template <typename T>
    349 T GetFilterDim(gtl::ArraySlice<T> dimension_attribute,
    350                FilterTensorFormat filter_tensor_format, char dimension) {
    351   int index = (GetFilterTensorSpatialDims(dimension_attribute.size(),
    352                                           filter_tensor_format) == 3)
    353                   ? GetFilterDimIndex<3>(filter_tensor_format, dimension)
    354                   : GetFilterDimIndex<2>(filter_tensor_format, dimension);
    355   CHECK(index >= 0 && index < dimension_attribute.size())
    356       << "Invalid index from the dimension: " << index << ", "
    357       << filter_tensor_format << ", " << dimension;
    358   return dimension_attribute[index];
    359 }
    360 
    361 template <typename T>
    362 T GetTensorDim(const std::vector<T>& attributes, TensorFormat format,
    363                char dimension) {
    364   return GetTensorDim(gtl::ArraySlice<T>(attributes), format, dimension);
    365 }
    366 
    367 // Return the size of the specified 'dimension' within 'tensor_shape'
    368 // according to 'tensor_format'.
    369 inline int64 GetTensorDim(const TensorShape& tensor_shape,
    370                           TensorFormat tensor_format, char dimension) {
    371   return GetTensorDim(gtl::ArraySlice<int64>(tensor_shape.dim_sizes()),
    372                       tensor_format, dimension);
    373 }
    374 
    375 // Return the size of the specified 'dimension' within 'tensor_shape'
    376 // according to 'tensor_filter_format'.
    377 inline int64 GetFilterDim(const TensorShape& tensor_shape,
    378                           FilterTensorFormat tensor_filter_format,
    379                           char dimension) {
    380   return GetFilterDim(gtl::ArraySlice<int64>(tensor_shape.dim_sizes()),
    381                       tensor_filter_format, dimension);
    382 }
    383 
    384 // Return the size of the specified 'dimension' of 'tensor' according to
    385 // 'tensor_format'.
    386 inline int64 GetTensorDim(const Tensor& tensor, TensorFormat tensor_format,
    387                           char dimension) {
    388   return GetTensorDim(tensor.shape(), tensor_format, dimension);
    389 }
    390 
    391 // Return the size of the specified 'dimension' of 'tensor' according to
    392 // 'filter_tensor_format'.
    393 inline int64 GetFilterDim(const Tensor& tensor,
    394                           FilterTensorFormat filter_tensor_format,
    395                           char dimension) {
    396   return GetFilterDim(tensor.shape(), filter_tensor_format, dimension);
    397 }
    398 
// Return the string that specifies the data format for convnet operations.
// The "3d" variant presumably serves 3-D convolution ops; the exact attribute
// text is defined in the implementation file (not visible here).
string GetConvnetDataFormatAttrString();
string GetConvnet3dDataFormatAttrString();

// Return the string that specifies the filter format for convnet operations.
// As above, the "3d" variant presumably serves 3-D convolution ops.
string GetConvnetFilterFormatAttrString();
string GetConvnet3dFilterFormatAttrString();
    406 
    407 // Return a tensor shape for the given format. Works for both 2D and 3D
    408 // operations. If format is FORMAT_NCHW_VECT_C, the output TensorShape has rank
    409 // spatial.size()+3 (N,C,spatial,InnerC); otherwise, it has rank
    410 // spatial.size()+2 (e.g. N,C,spatial or N,spatial,C).
    411 inline TensorShape ShapeFromFormat(TensorFormat format, int64 N,
    412                                    gtl::ArraySlice<int64> spatial, int64 C) {
    413   const int dims = GetTensorDimsFromSpatialDims(spatial.size(), format);
    414   gtl::InlinedVector<int64, 6> dim_sizes(dims);
    415   dim_sizes[GetTensorBatchDimIndex(dims, format)] = N;
    416   for (int dim = 0; static_cast<size_t>(dim) < spatial.size(); dim++) {
    417     dim_sizes[GetTensorSpatialDimIndex(dims, format, dim)] = spatial[dim];
    418   }
    419 
    420   int feature_index = GetTensorFeatureDimIndex(dims, format);
    421   if (format == FORMAT_NCHW_VECT_C) {
    422     CHECK_EQ(0, C % 4) << "NCHW_VECT_C requires C to be a multiple of 4, but C="
    423                        << C;
    424     dim_sizes[feature_index] = C / 4;
    425     dim_sizes[GetTensorInnerFeatureDimIndex(dims, format)] = 4;
    426   } else {
    427     dim_sizes[feature_index] = C;
    428   }
    429   return TensorShape(dim_sizes);
    430 }
    431 
    432 // Return a tensor shape of the specified 'format', and dimensions.
    433 // Works for both 2D and 3D operations. If 'format' is OIHW_VECT_I,
    434 // the output TensorShape has spatial.size() + 3 dimensions, otherwise
    435 // it has spatial.size() + 2 dimensions.
    436 inline TensorShape ShapeFromFilterTensorFormat(FilterTensorFormat format,
    437                                                gtl::ArraySlice<int64> spatial,
    438                                                int64 I, int64 O) {
    439   const int dims = GetFilterTensorDimsFromSpatialDims(spatial.size(), format);
    440   gtl::InlinedVector<int64, 6> dim_sizes(dims);
    441   dim_sizes[GetFilterTensorOutputChannelsDimIndex(dims, format)] = O;
    442   for (int dim = 0; static_cast<size_t>(dim) < spatial.size(); dim++) {
    443     dim_sizes[GetFilterTensorSpatialDimIndex(dims, format, dim)] = spatial[dim];
    444   }
    445 
    446   if (format == FORMAT_OIHW_VECT_I) {
    447     CHECK_EQ(0, I % 4) << "OIHW_VECT_I requires I to be a multiple of 4, but I="
    448                        << I;
    449     I /= 4;
    450     dim_sizes[GetFilterTensorInnerInputChannelsDimIndex(dims, format)] = 4;
    451   }
    452   dim_sizes[GetFilterTensorInputChannelsDimIndex(dims, format)] = I;
    453   return TensorShape(dim_sizes);
    454 }
    455 
    456 // Return a tensor shape of the specified 'format', and dimensions.
    457 inline TensorShape ShapeFromFormat(TensorFormat format, int64 N, int64 H,
    458                                    int64 W, int64 C) {
    459   return ShapeFromFormat(format, N, {H, W}, C);
    460 }
    461 
    462 // Return a filter tensor shape of the specified 'format', and dimensions.
    463 inline TensorShape ShapeFromFilterTensorFormat(FilterTensorFormat format,
    464                                                int64 H, int64 W, int64 I,
    465                                                int64 O) {
    466   return ShapeFromFilterTensorFormat(format, {H, W}, I, O);
    467 }
    468 
    469 // Returns a copy of the specified tensor 'src_shape' converted from
    470 // 'src_format' to 'dst_format'.
    471 inline TensorShape ShapeFromFormat(TensorFormat dst_format,
    472                                    const TensorShape& src_shape,
    473                                    TensorFormat src_format) {
    474   if (src_format == dst_format) {
    475     return src_shape;
    476   }
    477 
    478   const int64 batch = GetTensorDim(src_shape, src_format, 'N');
    479   const int64 channels = GetTensorDim(src_shape, src_format, 'C') *
    480                          (src_format == FORMAT_NCHW_VECT_C ? 4 : 1);
    481 
    482   if (GetTensorSpatialDims(src_shape.dims(), src_format) == 3) {
    483     return ShapeFromFormat(dst_format, batch,
    484                            {{GetTensorDim(src_shape, src_format, '0'),
    485                              GetTensorDim(src_shape, src_format, '1'),
    486                              GetTensorDim(src_shape, src_format, '2')}},
    487                            channels);
    488   }
    489 
    490   return ShapeFromFormat(dst_format, batch,
    491                          {{GetTensorDim(src_shape, src_format, 'H'),
    492                            GetTensorDim(src_shape, src_format, 'W')}},
    493                          channels);
    494 }
    495 
    496 // Returns a copy of the specified filter tensor 'src_shape' converted from
    497 // 'src_filter_format' to 'dst_filter_format'.
    498 inline TensorShape ShapeFromFilterFormat(FilterTensorFormat dst_filter_format,
    499                                          const TensorShape& src_shape,
    500                                          FilterTensorFormat src_filter_format) {
    501   if (src_filter_format == dst_filter_format) {
    502     return src_shape;
    503   }
    504 
    505   const int64 output_channels = GetFilterDim(src_shape, src_filter_format, 'O');
    506   const int64 input_channels =
    507       GetFilterDim(src_shape, src_filter_format, 'I') *
    508       (src_filter_format == FORMAT_OIHW_VECT_I ? 4 : 1);
    509 
    510   if (GetFilterTensorSpatialDims(src_shape.dims(), src_filter_format) == 3) {
    511     return ShapeFromFilterTensorFormat(
    512         dst_filter_format,
    513         {{GetFilterDim(src_shape, src_filter_format, '0'),
    514           GetFilterDim(src_shape, src_filter_format, '1'),
    515           GetFilterDim(src_shape, src_filter_format, '2')}},
    516         input_channels, output_channels);
    517   }
    518 
    519   return ShapeFromFilterTensorFormat(
    520       dst_filter_format,
    521       {{GetFilterDim(src_shape, src_filter_format, 'H'),
    522         GetFilterDim(src_shape, src_filter_format, 'W')}},
    523       input_channels, output_channels);
    524 }
    525 
    526 }  // namespace tensorflow
    527 
    528 #endif  // TENSORFLOW_UTIL_TENSOR_FORMAT_H_
    529