1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <utility> 17 18 #include "tensorflow/core/util/tensor_format.h" 19 20 #include "tensorflow/core/platform/logging.h" 21 #include "tensorflow/core/platform/test.h" 22 23 namespace tensorflow { 24 25 #define EnumStringPair(val) \ 26 { val, #val } 27 28 std::pair<TensorFormat, const char*> test_data_formats[] = { 29 EnumStringPair(FORMAT_NHWC), EnumStringPair(FORMAT_NCHW), 30 EnumStringPair(FORMAT_NCHW_VECT_C), EnumStringPair(FORMAT_NHWC_VECT_W), 31 EnumStringPair(FORMAT_HWNC), EnumStringPair(FORMAT_HWCN), 32 }; 33 34 std::pair<FilterTensorFormat, const char*> test_filter_formats[] = { 35 EnumStringPair(FORMAT_HWIO), 36 EnumStringPair(FORMAT_OIHW), 37 EnumStringPair(FORMAT_OIHW_VECT_I), 38 }; 39 40 // This is an alternative way of specifying the tensor dimension indexes for 41 // each tensor format. For now it can be used as a cross-check of the existing 42 // functions, but later could replace them. 43 44 // Represents the dimension indexes of an activations tensor format. 45 struct TensorDimMap { 46 int n() const { return dim_n; } 47 int h() const { return dim_h; } 48 int w() const { return dim_w; } 49 int c() const { return dim_c; } 50 int spatial(int spatial_index) const { return spatial_dim[spatial_index]; } 51 52 int dim_n, dim_h, dim_w, dim_c; 53 int spatial_dim[3]; 54 }; 55 56 // Represents the dimension indexes of a filter tensor format. 57 struct FilterDimMap { 58 int h() const { return dim_h; } 59 int w() const { return dim_w; } 60 int i() const { return dim_i; } 61 int o() const { return dim_o; } 62 int spatial(int spatial_index) const { return spatial_dim[spatial_index]; } 63 64 int dim_h, dim_w, dim_i, dim_o; 65 int spatial_dim[3]; 66 }; 67 68 // clang-format off 69 70 // Predefined constants specifying the actual dimension indexes for each 71 // supported tensor and filter format. 72 struct DimMaps { 73 #define StaCoExTensorDm static constexpr TensorDimMap 74 // 'N', 'H', 'W', 'C' 0, 1, 2 75 StaCoExTensorDm kTdmInvalid = { -1, -1, -1, -1, { -1, -1, -1 } }; 76 // These arrays are indexed by the number of spatial dimensions in the format. 77 StaCoExTensorDm kTdmNHWC[4] = { kTdmInvalid, 78 { 0, -1, 1, 2, { 1, -1, -1 } }, // 1D 79 { 0, 1, 2, 3, { 1, 2, -1 } }, // 2D 80 { 0, 2, 3, 4, { 1, 2, 3 } } // 3D 81 }; 82 StaCoExTensorDm kTdmNCHW[4] = { kTdmInvalid, 83 { 0, -1, 2, 1, { 2, -1, -1 } }, 84 { 0, 2, 3, 1, { 2, 3, -1 } }, 85 { 0, 3, 4, 1, { 2, 3, 4 } } 86 }; 87 StaCoExTensorDm kTdmHWNC[4] = { kTdmInvalid, 88 { 1, -1, 0, 2, { 0, -1, -1 } }, 89 { 2, 0, 1, 3, { 0, 1, -1 } }, 90 { 3, 1, 2, 4, { 0, 1, 2 } } 91 }; 92 StaCoExTensorDm kTdmHWCN[4] = { kTdmInvalid, 93 { 2, -1, 0, 1, { 0, -1, -1 } }, 94 { 3, 0, 1, 2, { 0, 1, -1 } }, 95 { 4, 1, 2, 3, { 0, 1, 2 } } 96 }; 97 #undef StaCoExTensorDm 98 #define StaCoExFilterDm static constexpr FilterDimMap 99 // 'H', 'W', 'I', 'O' 0 1 2 100 StaCoExFilterDm kFdmInvalid = { -1, -1, -1, -1, { -1, -1, -1 } }; 101 StaCoExFilterDm kFdmHWIO[4] = { kFdmInvalid, 102 { -1, 0, 1, 2, { 0, -1, -1 } }, 103 { 0, 1, 2, 3, { 0, 1, -1 } }, 104 { 1, 2, 3, 4, { 0, 1, 2 } } 105 }; 106 StaCoExFilterDm kFdmOIHW[4] = { kFdmInvalid, 107 { -1, 2, 1, 0, { 2, -1, -1 } }, 108 { 2, 3, 1, 0, { 2, 3, -1 } }, 109 { 3, 4, 1, 0, { 2, 3, 4 } } 110 }; 111 #undef StaCoExFilterDm 112 }; 113 114 inline constexpr const TensorDimMap& 115 GetTensorDimMap(const int num_spatial_dims, const TensorFormat format) { 116 return 117 (format == FORMAT_NHWC || 118 format == FORMAT_NHWC_VECT_W) ? DimMaps::kTdmNHWC[num_spatial_dims] : 119 (format == FORMAT_NCHW || 120 format == FORMAT_NCHW_VECT_C) ? DimMaps::kTdmNCHW[num_spatial_dims] : 121 (format == FORMAT_HWNC) ? DimMaps::kTdmHWNC[num_spatial_dims] : 122 (format == FORMAT_HWCN) ? DimMaps::kTdmHWCN[num_spatial_dims] 123 : DimMaps::kTdmInvalid; 124 } 125 126 inline constexpr const FilterDimMap& 127 GetFilterDimMap(const int num_spatial_dims, 128 const FilterTensorFormat format) { 129 return 130 (format == FORMAT_HWIO) ? DimMaps::kFdmHWIO[num_spatial_dims] : 131 (format == FORMAT_OIHW || 132 format == FORMAT_OIHW_VECT_I) ? DimMaps::kFdmOIHW[num_spatial_dims] 133 : DimMaps::kFdmInvalid; 134 } 135 // clang-format on 136 137 constexpr TensorDimMap DimMaps::kTdmInvalid; 138 constexpr TensorDimMap DimMaps::kTdmNHWC[4]; 139 constexpr TensorDimMap DimMaps::kTdmNCHW[4]; 140 constexpr TensorDimMap DimMaps::kTdmHWNC[4]; 141 constexpr TensorDimMap DimMaps::kTdmHWCN[4]; 142 constexpr FilterDimMap DimMaps::kFdmInvalid; 143 constexpr FilterDimMap DimMaps::kFdmHWIO[4]; 144 constexpr FilterDimMap DimMaps::kFdmOIHW[4]; 145 146 TEST(TensorFormatTest, FormatEnumsAndStrings) { 147 const string prefix = "FORMAT_"; 148 for (auto& test_data_format : test_data_formats) { 149 const char* stringified_format_enum = test_data_format.second; 150 LOG(INFO) << stringified_format_enum << " = " << test_data_format.first; 151 string expected_format_str = &stringified_format_enum[prefix.size()]; 152 TensorFormat format; 153 EXPECT_TRUE(FormatFromString(expected_format_str, &format)); 154 string format_str = ToString(format); 155 EXPECT_EQ(expected_format_str, format_str); 156 EXPECT_EQ(test_data_format.first, format); 157 } 158 for (auto& test_filter_format : test_filter_formats) { 159 const char* stringified_format_enum = test_filter_format.second; 160 LOG(INFO) << stringified_format_enum << " = " << test_filter_format.first; 161 string expected_format_str = &stringified_format_enum[prefix.size()]; 162 FilterTensorFormat format; 163 EXPECT_TRUE(FilterFormatFromString(expected_format_str, &format)); 164 string format_str = ToString(format); 165 EXPECT_EQ(expected_format_str, format_str); 166 EXPECT_EQ(test_filter_format.first, format); 167 } 168 } 169 170 template <int num_spatial_dims> 171 void RunDimensionIndexesTest() { 172 for (auto& test_data_format : test_data_formats) { 173 TensorFormat format = test_data_format.first; 174 auto& tdm = GetTensorDimMap(num_spatial_dims, format); 175 int num_dims = GetTensorDimsFromSpatialDims(num_spatial_dims, format); 176 LOG(INFO) << ToString(format) << ", num_spatial_dims=" << num_spatial_dims 177 << ", num_dims=" << num_dims; 178 EXPECT_EQ(GetTensorBatchDimIndex(num_dims, format), tdm.n()); 179 EXPECT_EQ(GetTensorDimIndex<num_spatial_dims>(format, 'N'), tdm.n()); 180 EXPECT_EQ(GetTensorFeatureDimIndex(num_dims, format), tdm.c()); 181 EXPECT_EQ(GetTensorDimIndex<num_spatial_dims>(format, 'C'), tdm.c()); 182 for (int i = 0; i < num_spatial_dims; ++i) { 183 EXPECT_EQ(GetTensorSpatialDimIndex(num_dims, format, i), tdm.spatial(i)); 184 EXPECT_EQ(GetTensorDimIndex<num_spatial_dims>(format, '0' + i), 185 tdm.spatial(i)); 186 } 187 } 188 for (auto& test_filter_format : test_filter_formats) { 189 FilterTensorFormat format = test_filter_format.first; 190 auto& fdm = GetFilterDimMap(num_spatial_dims, format); 191 int num_dims = GetFilterTensorDimsFromSpatialDims(num_spatial_dims, format); 192 LOG(INFO) << ToString(format) << ", num_spatial_dims=" << num_spatial_dims 193 << ", num_dims=" << num_dims; 194 EXPECT_EQ(GetFilterTensorOutputChannelsDimIndex(num_dims, format), fdm.o()); 195 EXPECT_EQ(GetFilterDimIndex<num_spatial_dims>(format, 'O'), fdm.o()); 196 EXPECT_EQ(GetFilterTensorInputChannelsDimIndex(num_dims, format), fdm.i()); 197 EXPECT_EQ(GetFilterDimIndex<num_spatial_dims>(format, 'I'), fdm.i()); 198 for (int i = 0; i < num_spatial_dims; ++i) { 199 EXPECT_EQ(GetFilterTensorSpatialDimIndex(num_dims, format, i), 200 fdm.spatial(i)); 201 EXPECT_EQ(GetFilterDimIndex<num_spatial_dims>(format, '0' + i), 202 fdm.spatial(i)); 203 } 204 } 205 } 206 207 TEST(TensorFormatTest, DimensionIndexes) { 208 RunDimensionIndexesTest<1>(); 209 RunDimensionIndexesTest<2>(); 210 RunDimensionIndexesTest<3>(); 211 } 212 213 } // namespace tensorflow 214