Home | History | Annotate | Download | only in util
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_
     17 #define TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_
     18 
     19 #include <array>
     20 #include <vector>
     21 
     22 #include "tensorflow/core/framework/tensor.h"
     23 #include "tensorflow/core/lib/gtl/inlined_vector.h"
     24 #include "tensorflow/core/platform/types.h"
     25 
     26 namespace tensorflow {
     27 
// Tensor format for input/output activations used in convolution operations.
// The mnemonics specify the meaning of each tensor dimension sorted from
// largest to smallest memory stride.
// N = Batch, H = Image Height, W = Image Width, C = Number of Channels.
// "H" and "W" stand for all spatial dimensions; the same formats are used for
// 3D tensors (e.g. NHWC also covers [N, D, H, W, C]).
// TODO(pauldonnelly): It would probably be better to switch to a registration
// process for tensor formats, so specialized formats could be defined more
// locally to where they are used.
enum TensorFormat {
  // FORMAT_NHWC is the default format in TensorFlow.
  FORMAT_NHWC = 0,

  // FORMAT_NCHW often improves performance on GPUs.
  FORMAT_NCHW = 1,

  // NCHW_VECT_C is the most performant tensor format for cudnn6's quantized
  // int8 convolution and fused convolution. It is laid out in the same order
  // as NCHW, except that the size of the Channels dimension is divided by 4,
  // and a new dimension of size 4 is appended, which packs 4 adjacent channel
  // activations for the same pixel into an int32. Thus an NCHW format tensor
  // with dimensions [N, C, H, W] would have dimensions [N, C/4, H, W, 4] in
  // NCHW_VECT_C format.
  // A pre-condition of this format is that C must be a multiple of 4.
  FORMAT_NCHW_VECT_C = 2,

  // Similar to NHWC, but the size of the W dimension is divided by 4, and a
  // new dimension of size 4 is appended, which packs 4 adjacent activations
  // in the width dimension. Thus an NHWC format tensor with dimensions
  // [N, H, W, C] would have dimensions [N, H, W/4, C, 4] in NHWC_VECT_W
  // format.
  // A pre-condition of this format is that W must be a multiple of 4.
  FORMAT_NHWC_VECT_W = 3,

  // Note: although the current code in this file assumes VECT_C and VECT_W
  // enums imply int8x4 vectors, this should not be relied upon.
  // In the future we may change the meaning of these enums to include vectors
  // of other types such as int16x2, with op implementations automatically
  // determining which format is implied based on the datatype.

  // FORMAT_HWNC is for TPUs. The spatial dimensions come first, followed by
  // the batch and then the channels dimension: [H, W, N, C].
  FORMAT_HWNC = 4,

  // FORMAT_HWCN is for TPUs. The spatial dimensions come first, followed by
  // the channels and then the batch dimension: [H, W, C, N].
  FORMAT_HWCN = 5,
};
     69 
// Tensor format for convolutional filters.
// The mnemonics specify the meaning of each tensor dimension sorted
// from largest to smallest memory stride.
// H = Kernel Height, W = Kernel Width, I = Input Channels, O = Output Channels.
// "H" and "W" stand for all spatial dimensions; 3D filters use the same
// formats with an additional leading spatial dimension.
// Note: In cudnnGetFilter4dDescriptor(), 'O' is called 'K', 'I' is called 'C'.
enum FilterTensorFormat {
  // FORMAT_HWIO is the default filter format in TensorFlow.
  // Ops that do not have a 'filter_format' attribute will assume this format.
  // Layout: [spatial..., I, O].
  FORMAT_HWIO = 0,

  // FORMAT_OIHW often improves performance on GPUs.
  // Layout: [O, I, spatial...].
  FORMAT_OIHW = 1,

  // OIHW_VECT_I is the most performant tensor format for cudnn6's quantized
  // int8 convolution and fused convolution. It is analogous to the NCHW_VECT_C
  // data format. It is laid out in the same order as OIHW, except that the size
  // of the Input Channels dimension is divided by 4, and a new dimension of
  // size 4 is appended, which packs 4 adjacent input channel weights into an
  // int32. Thus an OIHW format filter with dimensions [O, I, H, W] would have
  // dimensions [O, I/4, H, W, 4] in OIHW_VECT_I format.
  // A pre-condition of this format is that I must be a multiple of 4.
  FORMAT_OIHW_VECT_I = 2,
};
     93 
// Parse an activation tensor format from the given string.
// Returns true if the parsing succeeds, and false if it fails.
bool FormatFromString(const string& format_str, TensorFormat* format);

// Parse a filter tensor format from the given string.
// Returns true if the parsing succeeds, and false if it fails.
bool FilterFormatFromString(const string& format_str,
                            FilterTensorFormat* format);

// Convert an activation tensor format into a string.
string ToString(TensorFormat format);

// Convert a filter tensor format into a string.
string ToString(FilterTensorFormat format);
    108 
    109 // Returns the number of spatial dims of a tensor of rank 'num_dims' and tensor
    110 // format 'format'.
    111 inline int GetTensorSpatialDims(int num_dims, TensorFormat format) {
    112   switch (format) {
    113     case FORMAT_NHWC:
    114     case FORMAT_NCHW:
    115     case FORMAT_HWNC:
    116     case FORMAT_HWCN:
    117       return num_dims - 2;  // Exclude N,C.
    118     case FORMAT_NCHW_VECT_C:
    119     case FORMAT_NHWC_VECT_W:
    120       // Note: the VECT_W is not counted as an independent spatial dim here,
    121       // since it just a component of the width dimension.
    122       return num_dims - 3;  // Exclude N,C,VectDim.
    123   }
    124 }
    125 
    126 inline int GetFilterTensorSpatialDims(int num_dims, FilterTensorFormat format) {
    127   if (format == FORMAT_OIHW_VECT_I) {
    128     return num_dims - 3;  // Exclude O,I,InnerI.
    129   } else {
    130     return num_dims - 2;  // Exclude O,I.
    131   }
    132 }
    133 
    134 // Returns the rank of a tensor with 'num_spatial_dims' spatial dimensions and
    135 // tensor format 'format'. This is the inverse of GetTensorSpatialDims.
    136 inline int GetTensorDimsFromSpatialDims(int num_spatial_dims,
    137                                         TensorFormat format) {
    138   switch (format) {
    139     case FORMAT_NHWC:
    140     case FORMAT_NCHW:
    141     case FORMAT_HWNC:
    142     case FORMAT_HWCN:
    143       return num_spatial_dims + 2;  // Include N,C.
    144     case FORMAT_NCHW_VECT_C:
    145     case FORMAT_NHWC_VECT_W:
    146       return num_spatial_dims + 3;  // Include N,C,VectDim.
    147   }
    148 }
    149 
    150 // Returns the rank of a tensor with 'num_spatial_dims' spatial dimensions and
    151 // filter tensor format 'format'.
    152 inline int GetFilterTensorDimsFromSpatialDims(int num_spatial_dims,
    153                                               FilterTensorFormat format) {
    154   if (format == FORMAT_OIHW_VECT_I) {
    155     return num_spatial_dims + 3;  // Include O,I,InnerI.
    156   } else {
    157     return num_spatial_dims + 2;  // Include O,I.
    158   }
    159 }
    160 
    161 // Returns the index of the batch dimension.
    162 inline int GetTensorBatchDimIndex(int num_dims, TensorFormat format) {
    163   switch (format) {
    164     case FORMAT_NHWC:
    165     case FORMAT_NCHW:
    166     case FORMAT_NCHW_VECT_C:
    167     case FORMAT_NHWC_VECT_W:
    168       return 0;
    169     case FORMAT_HWNC:
    170       return num_dims - 2;
    171     case FORMAT_HWCN:
    172       return num_dims - 1;
    173     default:
    174       LOG(FATAL) << "Unknown format " << format;
    175       return -1;  // Avoid compiler warning about missing return value
    176   }
    177 }
    178 
    179 // Returns the index of the feature dimension. If format is NCHW_VECT_C, returns
    180 // the index of the outer feature dimension (i.e. dimension 1, whose size would
    181 // be num_features / 4 in this case).
    182 inline int GetTensorFeatureDimIndex(int num_dims, TensorFormat format) {
    183   switch (format) {
    184     case FORMAT_NHWC:
    185     case FORMAT_HWNC:
    186       return num_dims - 1;
    187     case FORMAT_NHWC_VECT_W:
    188     case FORMAT_HWCN:
    189       return num_dims - 2;
    190     case FORMAT_NCHW:
    191     case FORMAT_NCHW_VECT_C:
    192       return 1;
    193     default:
    194       LOG(FATAL) << "Unknown format " << format;
    195       return -1;  // Avoid compiler warning about missing return value
    196   }
    197 }
    198 
    199 // Returns the index of the inner feature dimension.
    200 inline int GetTensorInnerFeatureDimIndex(int num_dims, TensorFormat format) {
    201   DCHECK_EQ(format, FORMAT_NCHW_VECT_C);
    202   return num_dims - 1;
    203 }
    204 
    205 // Returns the index of the inner width dimension.
    206 inline int GetTensorInnerWidthDimIndex(int num_dims, TensorFormat format) {
    207   DCHECK_EQ(format, FORMAT_NHWC_VECT_W);
    208   return num_dims - 1;
    209 }
    210 
    211 // Returns the dimension index of the specified 'spatial_dim' within an
    212 // activation tensor. If format is NHWC_VECT_W and spatial_dim is 1, returns
    213 // the index of the outer width dimension (i.e. dimension 2, whose size would
    214 // be width / 4 in this case).
    215 inline int GetTensorSpatialDimIndex(int num_dims, TensorFormat format,
    216                                     int spatial_dim) {
    217   CHECK(spatial_dim >= 0 &&
    218         spatial_dim < GetTensorSpatialDims(num_dims, format))
    219       << spatial_dim << " " << num_dims << " " << ToString(format);
    220   switch (format) {
    221     case FORMAT_NHWC:
    222     case FORMAT_NHWC_VECT_W:
    223       return spatial_dim + 1;
    224     case FORMAT_NCHW:
    225     case FORMAT_NCHW_VECT_C:
    226       return spatial_dim + 2;
    227     case FORMAT_HWNC:
    228     case FORMAT_HWCN:
    229       return spatial_dim;
    230     default:
    231       LOG(FATAL) << "Unknown format " << format;
    232       return -1;  // Avoid compiler warning about missing return value
    233   }
    234 }
    235 
    236 inline int GetFilterTensorSpatialDimIndex(int num_dims,
    237                                           FilterTensorFormat format, int dim) {
    238   CHECK(dim >= 0 && dim < GetFilterTensorSpatialDims(num_dims, format))
    239       << dim << " " << num_dims << " " << ToString(format);
    240   switch (format) {
    241     case FORMAT_HWIO:
    242       return dim;
    243     case FORMAT_OIHW:
    244     case FORMAT_OIHW_VECT_I:
    245       return dim + 2;
    246     default:
    247       LOG(FATAL) << "Unknown format " << format;
    248       return -1;  // Avoid compiler warning about missing return value
    249   }
    250 }
    251 
    252 // Returns the index of the inner input channels dimension.
    253 inline int GetFilterTensorInnerInputChannelsDimIndex(
    254     int num_dims, FilterTensorFormat format) {
    255   DCHECK_EQ(format, FORMAT_OIHW_VECT_I);
    256   return num_dims - 1;
    257 }
    258 
    259 // Returns the index of the input channels dimension.
    260 // If 'format' is FORMAT_OIHW_VECT_I, returns the dimension index of the
    261 // outer input channel (i.e. 1), which holds num_input_channels / 4.
    262 inline int GetFilterTensorInputChannelsDimIndex(int num_dims,
    263                                                 FilterTensorFormat format) {
    264   switch (format) {
    265     case FORMAT_HWIO:
    266       return num_dims - 2;
    267     case FORMAT_OIHW:
    268     case FORMAT_OIHW_VECT_I:
    269       return 1;
    270     default:
    271       LOG(FATAL) << "Unknown format " << format;
    272       return -1;  // Avoid compiler warning about missing return value
    273   }
    274 }
    275 
    276 // Returns the index of the output channels dimension.
    277 inline int GetFilterTensorOutputChannelsDimIndex(int num_dims,
    278                                                  FilterTensorFormat format) {
    279   switch (format) {
    280     case FORMAT_HWIO:
    281       return num_dims - 1;
    282     case FORMAT_OIHW:
    283     case FORMAT_OIHW_VECT_I:
    284       return 0;
    285     default:
    286       LOG(FATAL) << "Unknown format " << format;
    287       return -1;  // Avoid compiler warning about missing return value
    288   }
    289 }
    290 
    291 // TODO(pauldonnelly): Replace these tensor dimension index functions with
    292 // constant structs to improve performance and reduce code size in Compute()
    293 // functions.
    294 
// Return the dimension index for the specified 'dimension' of the specified
// data 'tensor_format'. 'dimension' is a char that can be 'N' (batch size),
// 'C' (channels), 'H' (height), 'W' (width), or a numbered spatial dimension:
// '0', ..., (NUM_SPATIAL_DIMS-1).
// 'H' and 'W' always map to the last two spatial dimensions. For
// NUM_SPATIAL_DIMS == 3 they therefore alias '1' and '2'.
// If 'format' is NCHW_VECT_C and 'dimension' is 'C', returns the index of
// the outer channel dimension (i.e. 1). Likewise, for NHWC_VECT_W, 'W' maps
// to the outer width dimension.
// NOTE(review): numbered dimensions are not validated against
// NUM_SPATIAL_DIMS. For example, '2' with NUM_SPATIAL_DIMS == 2 in NHWC
// returns 3, which is the channel index. Callers must pass only valid
// spatial indices.
// An unrecognized 'dimension' or 'format' is a fatal error.
template <int NUM_SPATIAL_DIMS>
inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
  if (format == FORMAT_NHWC || format == FORMAT_NHWC_VECT_W) {
    // Layout: [N, spatial..., C] (plus a trailing InnerW for NHWC_VECT_W).
    // clang-format off
    switch (dimension) {
      case 'N': return 0;
      case '0': return 1;
      case '1': return 2;
      case '2': return 3;
      case 'H': return NUM_SPATIAL_DIMS - 1;
      case 'W': return NUM_SPATIAL_DIMS;
      case 'C': return NUM_SPATIAL_DIMS + 1;
      default:
        LOG(FATAL) << "Invalid dimension: " << dimension;
        return -1;  // Avoid compiler warning about missing return value
    }
  } else if (format == FORMAT_NCHW || format == FORMAT_NCHW_VECT_C) {
    // Layout: [N, C, spatial...] (plus a trailing InnerC for NCHW_VECT_C).
    switch (dimension) {
      case 'N': return 0;
      case 'C': return 1;
      case '0': return 2;
      case '1': return 3;
      case '2': return 4;
      case 'H': return NUM_SPATIAL_DIMS;
      case 'W': return NUM_SPATIAL_DIMS + 1;
      default:
        LOG(FATAL) << "Invalid dimension: " << dimension;
        return -1;  // Avoid compiler warning about missing return value
    }
  } else if (format == FORMAT_HWNC) {
    // Layout: [spatial..., N, C].
    switch (dimension) {
      case '0': return 0;
      case '1': return 1;
      case '2': return 2;
      case 'H': return NUM_SPATIAL_DIMS - 2;
      case 'W': return NUM_SPATIAL_DIMS - 1;
      case 'N': return NUM_SPATIAL_DIMS;
      case 'C': return NUM_SPATIAL_DIMS + 1;
      default:
        LOG(FATAL) << "Invalid dimension: " << dimension;
        return -1;  // Avoid compiler warning about missing return value
    }
  } else if (format == FORMAT_HWCN) {
    // Layout: [spatial..., C, N].
    switch (dimension) {
      case '0': return 0;
      case '1': return 1;
      case '2': return 2;
      case 'H': return NUM_SPATIAL_DIMS - 2;
      case 'W': return NUM_SPATIAL_DIMS - 1;
      case 'C': return NUM_SPATIAL_DIMS;
      case 'N': return NUM_SPATIAL_DIMS + 1;
      default:
        LOG(FATAL) << "Invalid dimension: " << dimension;
        return -1;  // Avoid compiler warning about missing return value
    }
  } else {
    LOG(FATAL) << "Invalid format: " << static_cast<int>(format);
    return -1;  // Avoid compiler warning about missing return value
  }
  // clang-format on
}
    362 
// Return the dimension index for the specified 'dimension' of the specified
// 'filter_tensor_format'. 'dimension' is a char that can be 'O' (num output
// channels), 'I' (num input channels), 'H' (height), 'W' (width), or a
// numbered spatial dimension: '0', ..., (NUM_SPATIAL_DIMS-1).
// 'H' and 'W' always map to the last two spatial dimensions.
// If 'format' is OIHW_VECT_I and 'dimension' is 'I', returns the index of the
// outer input channels dimension (i.e. 1).
// NOTE(review): numbered dimensions are not validated against
// NUM_SPATIAL_DIMS; callers must pass only valid spatial indices.
// An unrecognized 'dimension' or format is a fatal error.
template <int NUM_SPATIAL_DIMS>
inline int GetFilterDimIndex(FilterTensorFormat filter_tensor_format,
                             char dimension) {
  // clang-format off
  if (filter_tensor_format == FORMAT_HWIO) {
    // Layout: [spatial..., I, O].
    switch (dimension) {
      case '0': return 0;
      case '1': return 1;
      case '2': return 2;
      case 'H': return NUM_SPATIAL_DIMS - 2;
      case 'W': return NUM_SPATIAL_DIMS - 1;
      case 'I': return NUM_SPATIAL_DIMS;
      case 'O': return NUM_SPATIAL_DIMS + 1;
      default:
        LOG(FATAL) << "Invalid dimension: " << dimension;
        return -1;  // Avoid compiler warning about missing return value
    }
  } else if (filter_tensor_format == FORMAT_OIHW ||
             filter_tensor_format == FORMAT_OIHW_VECT_I) {
    // Layout: [O, I, spatial...] (plus a trailing InnerI for OIHW_VECT_I).
    switch (dimension) {
      case 'O': return 0;
      case 'I': return 1;
      case '0': return 2;
      case '1': return 3;
      case '2': return 4;
      case 'H': return NUM_SPATIAL_DIMS;
      case 'W': return NUM_SPATIAL_DIMS + 1;
      default:
        LOG(FATAL) << "Invalid dimension: " << dimension;
        return -1;  // Avoid compiler warning about missing return value
    }
  } else {
    LOG(FATAL) << "Invalid format: " << static_cast<int>(filter_tensor_format);
    return -1;  // Avoid compiler warning about missing return value
  }
  // clang-format on
}
    406 
    407 inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
    408   return GetTensorDimIndex<2>(format, dimension);
    409 }
    410 
    411 inline int32 GetTensorDimIndex(TensorFormat format, char dimension,
    412                                int num_total_dims) {
    413   int32 index = (GetTensorSpatialDims(num_total_dims, format) == 3)
    414                     ? GetTensorDimIndex<3>(format, dimension)
    415                     : GetTensorDimIndex<2>(format, dimension);
    416   CHECK(index >= 0 && index < num_total_dims)  // Crash OK.
    417       << "Invalid index from the dimension: " << index << ", " << format << ", "
    418       << dimension;
    419   return index;
    420 }
    421 
    422 // Return the element from 'dimension_attributes' that corresponds to the
    423 // specified 'dimension' according to 'tensor_format'.
    424 template <typename T>
    425 T GetTensorDim(gtl::ArraySlice<T> dimension_attributes,
    426                TensorFormat tensor_format, char dimension) {
    427   int index =
    428       GetTensorDimIndex(tensor_format, dimension, dimension_attributes.size());
    429   return dimension_attributes[index];
    430 }
    431 
    432 // Return the element from 'dimension_attribute' that corresponds to the
    433 // specified 'dimension' according to 'filter_tensor_format'.
    434 template <typename T>
    435 T GetFilterDim(gtl::ArraySlice<T> dimension_attribute,
    436                FilterTensorFormat filter_tensor_format, char dimension) {
    437   int index = (GetFilterTensorSpatialDims(dimension_attribute.size(),
    438                                           filter_tensor_format) == 3)
    439                   ? GetFilterDimIndex<3>(filter_tensor_format, dimension)
    440                   : GetFilterDimIndex<2>(filter_tensor_format, dimension);
    441   CHECK(index >= 0 && index < dimension_attribute.size())
    442       << "Invalid index from the dimension: " << index << ", "
    443       << filter_tensor_format << ", " << dimension;
    444   return dimension_attribute[index];
    445 }
    446 
    447 template <typename T>
    448 T GetTensorDim(const std::vector<T>& attributes, TensorFormat format,
    449                char dimension) {
    450   return GetTensorDim(gtl::ArraySlice<T>(attributes), format, dimension);
    451 }
    452 
    453 // Return the size of the specified 'dimension' within 'tensor_shape'
    454 // according to 'tensor_format'.
    455 inline int64 GetTensorDim(const TensorShape& tensor_shape,
    456                           TensorFormat tensor_format, char dimension) {
    457   return GetTensorDim(gtl::ArraySlice<int64>(tensor_shape.dim_sizes()),
    458                       tensor_format, dimension);
    459 }
    460 
    461 // Return the size of the specified 'dimension' within 'tensor_shape'
    462 // according to 'tensor_filter_format'.
    463 inline int64 GetFilterDim(const TensorShape& tensor_shape,
    464                           FilterTensorFormat tensor_filter_format,
    465                           char dimension) {
    466   return GetFilterDim(gtl::ArraySlice<int64>(tensor_shape.dim_sizes()),
    467                       tensor_filter_format, dimension);
    468 }
    469 
    470 // Return the size of the specified 'dimension' of 'tensor' according to
    471 // 'tensor_format'.
    472 inline int64 GetTensorDim(const Tensor& tensor, TensorFormat tensor_format,
    473                           char dimension) {
    474   return GetTensorDim(tensor.shape(), tensor_format, dimension);
    475 }
    476 
    477 // Return the size of the specified 'dimension' of 'tensor' according to
    478 // 'filter_tensor_format'.
    479 inline int64 GetFilterDim(const Tensor& tensor,
    480                           FilterTensorFormat filter_tensor_format,
    481                           char dimension) {
    482   return GetFilterDim(tensor.shape(), filter_tensor_format, dimension);
    483 }
    484 
    485 inline void GetExplicitPaddingForDim(
    486     const std::vector<int64>& explicit_paddings, TensorFormat tensor_format,
    487     char dimension, int64* padding_before, int64* padding_after) {
    488   int index =
    489       GetTensorDimIndex(tensor_format, dimension, explicit_paddings.size() / 2);
    490   *padding_before = explicit_paddings[2 * index];
    491   *padding_after = explicit_paddings[2 * index + 1];
    492 }
    493 
// Return the string that specifies the data format attr for 2D and 3D convnet
// operations, respectively.
string GetConvnetDataFormatAttrString();
string GetConvnet3dDataFormatAttrString();

// Return the string that specifies the filter format attr for 2D and 3D
// convnet operations, respectively.
string GetConvnetFilterFormatAttrString();
string GetConvnet3dFilterFormatAttrString();
// NOTE(review): despite sitting with the filter-format helpers, the name
// suggests this returns a *data* format attr string that covers both 2D and
// 3D convnet ops. Confirm against the definition.
string GetConvnetDataFormat2D3DAttrString();
    502 
    503 // Returns a tensor shape for the specified format and dimension sizes.
    504 // Works for both 2D and 3D operations. The output shapes are as follows:
    505 // FORMAT_NHWC:        (N, spatial, C); rank = spatial.size() + 2
    506 // FORMAT_NCHW:        (N, C, spatial); rank = spatial.size() + 2
    507 // FORMAT_NCHW_VECT_C: (N, C, spatial, InnerC); rank = spatial.size() + 3
    508 // FORMAT_NHWC_VECT_W: (N, spatial, C, InnerW); rank = spatial.size() + 3
    509 inline TensorShape ShapeFromFormat(TensorFormat format, int64 N,
    510                                    gtl::ArraySlice<int64> spatial, int64 C) {
    511   const int dims = GetTensorDimsFromSpatialDims(spatial.size(), format);
    512   gtl::InlinedVector<int64, 6> dim_sizes(dims);
    513   dim_sizes[GetTensorBatchDimIndex(dims, format)] = N;
    514   for (int dim = 0; static_cast<size_t>(dim) < spatial.size(); dim++) {
    515     auto dim_size = spatial[dim];
    516     if (format == FORMAT_NHWC_VECT_W &&
    517         static_cast<size_t>(dim) == spatial.size() - 1) {
    518       CHECK_EQ(0, dim_size % 4)
    519           << "FORMAT_NHWC_VECT_W requires W to be a multiple of 4, but W="
    520           << dim_size;
    521       dim_sizes[GetTensorInnerWidthDimIndex(dims, format)] = 4;
    522       dim_size /= 4;
    523     }
    524     dim_sizes[GetTensorSpatialDimIndex(dims, format, dim)] = dim_size;
    525   }
    526 
    527   int feature_index = GetTensorFeatureDimIndex(dims, format);
    528   if (format == FORMAT_NCHW_VECT_C) {
    529     CHECK_EQ(0, C % 4) << "NCHW_VECT_C requires C to be a multiple of 4, but C="
    530                        << C;
    531     C /= 4;
    532     dim_sizes[GetTensorInnerFeatureDimIndex(dims, format)] = 4;
    533   }
    534   dim_sizes[feature_index] = C;
    535   return TensorShape(dim_sizes);
    536 }
    537 
    538 // Return a tensor shape of the specified 'format', and dimensions.
    539 // Works for both 2D and 3D operations. If 'format' is OIHW_VECT_I,
    540 // the output TensorShape has spatial.size() + 3 dimensions, otherwise
    541 // it has spatial.size() + 2 dimensions.
    542 inline TensorShape ShapeFromFilterTensorFormat(FilterTensorFormat format,
    543                                                gtl::ArraySlice<int64> spatial,
    544                                                int64 I, int64 O) {
    545   const int dims = GetFilterTensorDimsFromSpatialDims(spatial.size(), format);
    546   gtl::InlinedVector<int64, 6> dim_sizes(dims);
    547   dim_sizes[GetFilterTensorOutputChannelsDimIndex(dims, format)] = O;
    548   for (int dim = 0; static_cast<size_t>(dim) < spatial.size(); dim++) {
    549     dim_sizes[GetFilterTensorSpatialDimIndex(dims, format, dim)] = spatial[dim];
    550   }
    551 
    552   if (format == FORMAT_OIHW_VECT_I) {
    553     CHECK_EQ(0, I % 4) << "OIHW_VECT_I requires I to be a multiple of 4, but I="
    554                        << I;
    555     I /= 4;
    556     dim_sizes[GetFilterTensorInnerInputChannelsDimIndex(dims, format)] = 4;
    557   }
    558   dim_sizes[GetFilterTensorInputChannelsDimIndex(dims, format)] = I;
    559   return TensorShape(dim_sizes);
    560 }
    561 
    562 // Return a tensor shape of the specified 'format', and dimensions.
    563 inline TensorShape ShapeFromFormat(TensorFormat format, int64 N, int64 H,
    564                                    int64 W, int64 C) {
    565   return ShapeFromFormat(format, N, {H, W}, C);
    566 }
    567 
    568 // Return a filter tensor shape of the specified 'format', and dimensions.
    569 inline TensorShape ShapeFromFilterTensorFormat(FilterTensorFormat format,
    570                                                int64 H, int64 W, int64 I,
    571                                                int64 O) {
    572   return ShapeFromFilterTensorFormat(format, {H, W}, I, O);
    573 }
    574 
    575 // Returns a copy of the specified tensor 'src_shape' converted from
    576 // 'src_format' to 'dst_format'.
    577 inline TensorShape ShapeFromFormat(TensorFormat dst_format,
    578                                    const TensorShape& src_shape,
    579                                    TensorFormat src_format) {
    580   if (src_format == dst_format) {
    581     return src_shape;
    582   }
    583 
    584   const int64 batch = GetTensorDim(src_shape, src_format, 'N');
    585   const int64 channels = GetTensorDim(src_shape, src_format, 'C') *
    586                          (src_format == FORMAT_NCHW_VECT_C ? 4 : 1);
    587   const int num_src_spatial_dims =
    588       GetTensorSpatialDims(src_shape.dims(), src_format);
    589   std::vector<int64> spatial_dims(num_src_spatial_dims);
    590   for (int spatial_dim = 0; spatial_dim < num_src_spatial_dims; ++spatial_dim) {
    591     spatial_dims[spatial_dim] =
    592         gtl::ArraySlice<int64>(src_shape.dim_sizes())[GetTensorSpatialDimIndex(
    593             src_shape.dims(), src_format, spatial_dim)];
    594   }
    595   if (src_format == FORMAT_NHWC_VECT_W) {
    596     spatial_dims[num_src_spatial_dims - 1] *= 4;
    597   }
    598   return ShapeFromFormat(dst_format, batch, {spatial_dims}, channels);
    599 }
    600 
    601 // Returns a copy of the specified filter tensor 'src_shape' converted from
    602 // 'src_filter_format' to 'dst_filter_format'.
    603 inline TensorShape ShapeFromFilterFormat(FilterTensorFormat dst_filter_format,
    604                                          const TensorShape& src_shape,
    605                                          FilterTensorFormat src_filter_format) {
    606   if (src_filter_format == dst_filter_format) {
    607     return src_shape;
    608   }
    609 
    610   const int64 output_channels = GetFilterDim(src_shape, src_filter_format, 'O');
    611   const int64 input_channels =
    612       GetFilterDim(src_shape, src_filter_format, 'I') *
    613       (src_filter_format == FORMAT_OIHW_VECT_I ? 4 : 1);
    614 
    615   if (GetFilterTensorSpatialDims(src_shape.dims(), src_filter_format) == 3) {
    616     return ShapeFromFilterTensorFormat(
    617         dst_filter_format,
    618         {{GetFilterDim(src_shape, src_filter_format, '0'),
    619           GetFilterDim(src_shape, src_filter_format, '1'),
    620           GetFilterDim(src_shape, src_filter_format, '2')}},
    621         input_channels, output_channels);
    622   }
    623 
    624   return ShapeFromFilterTensorFormat(
    625       dst_filter_format,
    626       {{GetFilterDim(src_shape, src_filter_format, 'H'),
    627         GetFilterDim(src_shape, src_filter_format, 'W')}},
    628       input_channels, output_channels);
    629 }
    630 
    631 }  // namespace tensorflow
    632 
    633 #endif  // TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_
    634