Home | History | Annotate | Download | only in util
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <utility>
     17 
     18 #include "tensorflow/core/util/tensor_format.h"
     19 
     20 #include "tensorflow/core/platform/logging.h"
     21 #include "tensorflow/core/platform/test.h"
     22 
     23 namespace tensorflow {
     24 
     25 #define EnumStringPair(val) \
     26   { val, #val }
     27 
     28 std::pair<TensorFormat, const char*> test_data_formats[] = {
     29     EnumStringPair(FORMAT_NHWC),        EnumStringPair(FORMAT_NCHW),
     30     EnumStringPair(FORMAT_NCHW_VECT_C), EnumStringPair(FORMAT_NHWC_VECT_W),
     31     EnumStringPair(FORMAT_HWNC),        EnumStringPair(FORMAT_HWCN),
     32 };
     33 
     34 std::pair<FilterTensorFormat, const char*> test_filter_formats[] = {
     35     EnumStringPair(FORMAT_HWIO),
     36     EnumStringPair(FORMAT_OIHW),
     37     EnumStringPair(FORMAT_OIHW_VECT_I),
     38 };
     39 
     40 // This is an alternative way of specifying the tensor dimension indexes for
     41 // each tensor format. For now it can be used as a cross-check of the existing
     42 // functions, but later could replace them.
     43 
     44 // Represents the dimension indexes of an activations tensor format.
     45 struct TensorDimMap {
     46   int n() const { return dim_n; }
     47   int h() const { return dim_h; }
     48   int w() const { return dim_w; }
     49   int c() const { return dim_c; }
     50   int spatial(int spatial_index) const { return spatial_dim[spatial_index]; }
     51 
     52   int dim_n, dim_h, dim_w, dim_c;
     53   int spatial_dim[3];
     54 };
     55 
     56 // Represents the dimension indexes of a filter tensor format.
     57 struct FilterDimMap {
     58   int h() const { return dim_h; }
     59   int w() const { return dim_w; }
     60   int i() const { return dim_i; }
     61   int o() const { return dim_o; }
     62   int spatial(int spatial_index) const { return spatial_dim[spatial_index]; }
     63 
     64   int dim_h, dim_w, dim_i, dim_o;
     65   int spatial_dim[3];
     66 };
     67 
     68 // clang-format off
     69 
     70 // Predefined constants specifying the actual dimension indexes for each
     71 // supported tensor and filter format.
     72 struct DimMaps {
     73 #define StaCoExTensorDm static constexpr TensorDimMap
     74   //                                'N', 'H', 'W', 'C'    0,  1,  2
     75   StaCoExTensorDm kTdmInvalid =   { -1,  -1,  -1,  -1, { -1, -1, -1 } };
     76   // These arrays are indexed by the number of spatial dimensions in the format.
     77   StaCoExTensorDm kTdmNHWC[4] = { kTdmInvalid,
     78                                   {  0,  -1,   1,   2, {  1, -1, -1 } },  // 1D
     79                                   {  0,   1,   2,   3, {  1,  2, -1 } },  // 2D
     80                                   {  0,   2,   3,   4, {  1,  2,  3 } }   // 3D
     81                                 };
     82   StaCoExTensorDm kTdmNCHW[4] = { kTdmInvalid,
     83                                   {  0,  -1,   2,   1, {  2, -1, -1 } },
     84                                   {  0,   2,   3,   1, {  2,  3, -1 } },
     85                                   {  0,   3,   4,   1, {  2,  3,  4 } }
     86                                 };
     87   StaCoExTensorDm kTdmHWNC[4] = { kTdmInvalid,
     88                                   {  1,  -1,   0,   2, {  0, -1, -1 } },
     89                                   {  2,   0,   1,   3, {  0,  1, -1 } },
     90                                   {  3,   1,   2,   4, {  0,  1,  2 } }
     91                                 };
     92   StaCoExTensorDm kTdmHWCN[4] = { kTdmInvalid,
     93                                   {  2,  -1,   0,   1, {  0, -1, -1 } },
     94                                   {  3,   0,   1,   2, {  0,  1, -1 } },
     95                                   {  4,   1,   2,   3, {  0,  1,  2 } }
     96                                 };
     97 #undef StaCoExTensorDm
     98 #define StaCoExFilterDm static constexpr FilterDimMap
     99   //                                'H', 'W', 'I', 'O'    0   1   2
    100   StaCoExFilterDm kFdmInvalid =   { -1,  -1,  -1,  -1, { -1, -1, -1 } };
    101   StaCoExFilterDm kFdmHWIO[4] = { kFdmInvalid,
    102                                   { -1,   0,   1,   2, {  0, -1, -1 } },
    103                                   {  0,   1,   2,   3, {  0,  1, -1 } },
    104                                   {  1,   2,   3,   4, {  0,  1,  2 } }
    105                                 };
    106   StaCoExFilterDm kFdmOIHW[4] = { kFdmInvalid,
    107                                   { -1,   2,   1,   0, {  2, -1, -1 } },
    108                                   {  2,   3,   1,   0, {  2,  3, -1 } },
    109                                   {  3,   4,   1,   0, {  2,  3,  4 } }
    110                                 };
    111 #undef StaCoExFilterDm
    112 };
    113 
    114 inline constexpr const TensorDimMap&
    115 GetTensorDimMap(const int num_spatial_dims, const TensorFormat format) {
    116   return
    117       (format == FORMAT_NHWC ||
    118        format == FORMAT_NHWC_VECT_W) ? DimMaps::kTdmNHWC[num_spatial_dims] :
    119       (format == FORMAT_NCHW ||
    120        format == FORMAT_NCHW_VECT_C) ? DimMaps::kTdmNCHW[num_spatial_dims] :
    121       (format == FORMAT_HWNC) ? DimMaps::kTdmHWNC[num_spatial_dims] :
    122       (format == FORMAT_HWCN) ? DimMaps::kTdmHWCN[num_spatial_dims]
    123                               : DimMaps::kTdmInvalid;
    124 }
    125 
    126 inline constexpr const FilterDimMap&
    127 GetFilterDimMap(const int num_spatial_dims,
    128                 const FilterTensorFormat format) {
    129   return
    130       (format == FORMAT_HWIO) ? DimMaps::kFdmHWIO[num_spatial_dims] :
    131       (format == FORMAT_OIHW ||
    132        format == FORMAT_OIHW_VECT_I) ? DimMaps::kFdmOIHW[num_spatial_dims]
    133                                      : DimMaps::kFdmInvalid;
    134 }
    135 // clang-format on
    136 
    137 constexpr TensorDimMap DimMaps::kTdmInvalid;
    138 constexpr TensorDimMap DimMaps::kTdmNHWC[4];
    139 constexpr TensorDimMap DimMaps::kTdmNCHW[4];
    140 constexpr TensorDimMap DimMaps::kTdmHWNC[4];
    141 constexpr TensorDimMap DimMaps::kTdmHWCN[4];
    142 constexpr FilterDimMap DimMaps::kFdmInvalid;
    143 constexpr FilterDimMap DimMaps::kFdmHWIO[4];
    144 constexpr FilterDimMap DimMaps::kFdmOIHW[4];
    145 
    146 TEST(TensorFormatTest, FormatEnumsAndStrings) {
    147   const string prefix = "FORMAT_";
    148   for (auto& test_data_format : test_data_formats) {
    149     const char* stringified_format_enum = test_data_format.second;
    150     LOG(INFO) << stringified_format_enum << " = " << test_data_format.first;
    151     string expected_format_str = &stringified_format_enum[prefix.size()];
    152     TensorFormat format;
    153     EXPECT_TRUE(FormatFromString(expected_format_str, &format));
    154     string format_str = ToString(format);
    155     EXPECT_EQ(expected_format_str, format_str);
    156     EXPECT_EQ(test_data_format.first, format);
    157   }
    158   for (auto& test_filter_format : test_filter_formats) {
    159     const char* stringified_format_enum = test_filter_format.second;
    160     LOG(INFO) << stringified_format_enum << " = " << test_filter_format.first;
    161     string expected_format_str = &stringified_format_enum[prefix.size()];
    162     FilterTensorFormat format;
    163     EXPECT_TRUE(FilterFormatFromString(expected_format_str, &format));
    164     string format_str = ToString(format);
    165     EXPECT_EQ(expected_format_str, format_str);
    166     EXPECT_EQ(test_filter_format.first, format);
    167   }
    168 }
    169 
    170 template <int num_spatial_dims>
    171 void RunDimensionIndexesTest() {
    172   for (auto& test_data_format : test_data_formats) {
    173     TensorFormat format = test_data_format.first;
    174     auto& tdm = GetTensorDimMap(num_spatial_dims, format);
    175     int num_dims = GetTensorDimsFromSpatialDims(num_spatial_dims, format);
    176     LOG(INFO) << ToString(format) << ", num_spatial_dims=" << num_spatial_dims
    177               << ", num_dims=" << num_dims;
    178     EXPECT_EQ(GetTensorBatchDimIndex(num_dims, format), tdm.n());
    179     EXPECT_EQ(GetTensorDimIndex<num_spatial_dims>(format, 'N'), tdm.n());
    180     EXPECT_EQ(GetTensorFeatureDimIndex(num_dims, format), tdm.c());
    181     EXPECT_EQ(GetTensorDimIndex<num_spatial_dims>(format, 'C'), tdm.c());
    182     for (int i = 0; i < num_spatial_dims; ++i) {
    183       EXPECT_EQ(GetTensorSpatialDimIndex(num_dims, format, i), tdm.spatial(i));
    184       EXPECT_EQ(GetTensorDimIndex<num_spatial_dims>(format, '0' + i),
    185                 tdm.spatial(i));
    186     }
    187   }
    188   for (auto& test_filter_format : test_filter_formats) {
    189     FilterTensorFormat format = test_filter_format.first;
    190     auto& fdm = GetFilterDimMap(num_spatial_dims, format);
    191     int num_dims = GetFilterTensorDimsFromSpatialDims(num_spatial_dims, format);
    192     LOG(INFO) << ToString(format) << ", num_spatial_dims=" << num_spatial_dims
    193               << ", num_dims=" << num_dims;
    194     EXPECT_EQ(GetFilterTensorOutputChannelsDimIndex(num_dims, format), fdm.o());
    195     EXPECT_EQ(GetFilterDimIndex<num_spatial_dims>(format, 'O'), fdm.o());
    196     EXPECT_EQ(GetFilterTensorInputChannelsDimIndex(num_dims, format), fdm.i());
    197     EXPECT_EQ(GetFilterDimIndex<num_spatial_dims>(format, 'I'), fdm.i());
    198     for (int i = 0; i < num_spatial_dims; ++i) {
    199       EXPECT_EQ(GetFilterTensorSpatialDimIndex(num_dims, format, i),
    200                 fdm.spatial(i));
    201       EXPECT_EQ(GetFilterDimIndex<num_spatial_dims>(format, '0' + i),
    202                 fdm.spatial(i));
    203     }
    204   }
    205 }
    206 
    207 TEST(TensorFormatTest, DimensionIndexes) {
    208   RunDimensionIndexesTest<1>();
    209   RunDimensionIndexesTest<2>();
    210   RunDimensionIndexesTest<3>();
    211 }
    212 
    213 }  // namespace tensorflow
    214