1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_ 17 #define TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_ 18 19 #include "tensorflow/core/kernels/deep_conv2d.h" 20 21 namespace tensorflow { 22 23 // Winograd DeepConv2DTransform implementation for 3x3 filters. 24 // Details: 25 // *) Arithmetic complexity of computations: Shmuel Winograd 26 // *) Fast Algorithms for Convolutional Neural Networks: Lavin, Gray 27 28 template <typename T> 29 class WinogradTransform : public DeepConv2DTransform<T> { 30 public: 31 typedef typename DeepConv2DTransform<T>::Shape Shape; 32 33 WinogradTransform() 34 : filter_shape_(3, 3), input_shape_(4, 4), output_shape_(2, 2) {} 35 36 virtual void GetFilterTransformMatrix(const int64 rows, const int64 cols, 37 T* transform_matrix) const; 38 39 virtual void GetInputTransformMatrix(const int64 rows, const int64 cols, 40 T* transform_matrix) const; 41 42 virtual void GetOutputTransformMatrix(const int64 rows, const int64 cols, 43 T* transform_matrix) const; 44 45 virtual const Shape& filter_shape() const { return filter_shape_; } 46 virtual const Shape& input_shape() const { return input_shape_; } 47 virtual const Shape& output_shape() const { return output_shape_; } 48 49 private: 50 const Shape filter_shape_; 51 const Shape input_shape_; 52 const Shape output_shape_; 53 }; 54 55 // The filter transform matrix is the kronecker product 'M * M' of the 56 // following matrix 'M': 57 // 58 // [ 1 0 0 ] 59 // [ 1/2 1/2 1/2 ] 60 // [ 1/2 -1/2 1/2 ] 61 // [ 0 0 1 ] 62 // 63 // The data layout of 'transform_matrix': 64 // [input_tile_spatial_size, filter_spatial_size] 65 // 66 template <typename T> 67 void WinogradTransform<T>::GetFilterTransformMatrix(const int64 rows, 68 const int64 cols, 69 T* transform_matrix) const { 70 CHECK_GT(rows, 0); 71 CHECK_GT(cols, 0); 72 memset(transform_matrix, 0, sizeof(T) * rows * cols); 73 74 // Sub matrix [0,0] 75 transform_matrix[0 * cols + 0] = T(1.0); 76 77 transform_matrix[1 * cols + 0] = T(0.5); 78 transform_matrix[1 * cols + 1] = T(0.5); 79 transform_matrix[1 * cols + 2] = T(0.5); 80 81 transform_matrix[2 * cols + 0] = T(0.5); 82 transform_matrix[2 * cols + 1] = T(-0.5); 83 transform_matrix[2 * cols + 2] = T(0.5); 84 85 transform_matrix[3 * cols + 2] = T(1.0); 86 87 // Sub matrix [1,0] 88 transform_matrix[4 * cols + 0] = T(0.5); 89 90 transform_matrix[5 * cols + 0] = T(0.25); 91 transform_matrix[5 * cols + 1] = T(0.25); 92 transform_matrix[5 * cols + 2] = T(0.25); 93 94 transform_matrix[6 * cols + 0] = T(0.25); 95 transform_matrix[6 * cols + 1] = T(-0.25); 96 transform_matrix[6 * cols + 2] = T(0.25); 97 98 transform_matrix[7 * cols + 2] = T(0.5); 99 100 // Sub matrix [1,1] 101 transform_matrix[4 * cols + 3] = T(0.5); 102 103 transform_matrix[5 * cols + 3] = T(0.25); 104 transform_matrix[5 * cols + 4] = T(0.25); 105 transform_matrix[5 * cols + 5] = T(0.25); 106 107 transform_matrix[6 * cols + 3] = T(0.25); 108 transform_matrix[6 * cols + 4] = T(-0.25); 109 transform_matrix[6 * cols + 5] = T(0.25); 110 111 transform_matrix[7 * cols + 5] = T(0.5); 112 113 // Sub matrix [1,2] 114 transform_matrix[4 * cols + 6] = T(0.5); 115 116 transform_matrix[5 * cols + 6] = T(0.25); 117 transform_matrix[5 * cols + 7] = T(0.25); 118 transform_matrix[5 * cols + 8] = T(0.25); 119 120 transform_matrix[6 * cols + 6] = T(0.25); 121 transform_matrix[6 * cols + 7] = T(-0.25); 122 transform_matrix[6 * cols + 8] = T(0.25); 123 124 transform_matrix[7 * cols + 8] = T(0.5); 125 126 // Sub matrix [2,0] 127 transform_matrix[8 * cols + 0] = T(0.5); 128 129 transform_matrix[9 * cols + 0] = T(0.25); 130 transform_matrix[9 * cols + 1] = T(0.25); 131 transform_matrix[9 * cols + 2] = T(0.25); 132 133 transform_matrix[10 * cols + 0] = T(0.25); 134 transform_matrix[10 * cols + 1] = T(-0.25); 135 transform_matrix[10 * cols + 2] = T(0.25); 136 137 transform_matrix[11 * cols + 2] = T(0.5); 138 139 // Sub matrix [2,1] 140 transform_matrix[8 * cols + 3] = T(-0.5); 141 142 transform_matrix[9 * cols + 3] = T(-0.25); 143 transform_matrix[9 * cols + 4] = T(-0.25); 144 transform_matrix[9 * cols + 5] = T(-0.25); 145 146 transform_matrix[10 * cols + 3] = T(-0.25); 147 transform_matrix[10 * cols + 4] = T(0.25); 148 transform_matrix[10 * cols + 5] = T(-0.25); 149 150 transform_matrix[11 * cols + 5] = T(-0.5); 151 152 // Sub matrix [2,2] 153 transform_matrix[8 * cols + 6] = T(0.5); 154 155 transform_matrix[9 * cols + 6] = T(0.25); 156 transform_matrix[9 * cols + 7] = T(0.25); 157 transform_matrix[9 * cols + 8] = T(0.25); 158 159 transform_matrix[10 * cols + 6] = T(0.25); 160 transform_matrix[10 * cols + 7] = T(-0.25); 161 transform_matrix[10 * cols + 8] = T(0.25); 162 163 transform_matrix[11 * cols + 8] = T(0.5); 164 165 // Sub matrix [3,2] 166 transform_matrix[12 * cols + 6] = T(1.0); 167 168 transform_matrix[13 * cols + 6] = T(0.5); 169 transform_matrix[13 * cols + 7] = T(0.5); 170 transform_matrix[13 * cols + 8] = T(0.5); 171 172 transform_matrix[14 * cols + 6] = T(0.5); 173 transform_matrix[14 * cols + 7] = T(-0.5); 174 transform_matrix[14 * cols + 8] = T(0.5); 175 176 transform_matrix[15 * cols + 8] = T(1.0); 177 } 178 179 // The input transform matrix is the kronecker product 'M * M' of the 180 // following matrix 'M': 181 // 182 // [1 0 -1 0] 183 // [0 1 1 0] 184 // [0 -1 1 0] 185 // [0 1 0 -1] 186 // 187 // Data layout of 'transform_matrix': 188 // [tile_spatial_size, tile_spatial_size] 189 // 190 template <typename T> 191 void WinogradTransform<T>::GetInputTransformMatrix(const int64 rows, 192 const int64 cols, 193 T* transform_matrix) const { 194 CHECK_GT(rows, 0); 195 CHECK_GT(cols, 0); 196 memset(transform_matrix, 0, sizeof(T) * rows * cols); 197 198 // Sub matrix [0,0] 199 transform_matrix[0 * cols + 0] = T(1.0); 200 transform_matrix[0 * cols + 2] = T(-1.0); 201 202 transform_matrix[1 * cols + 1] = T(1.0); 203 transform_matrix[1 * cols + 2] = T(1.0); 204 205 transform_matrix[2 * cols + 1] = T(-1.0); 206 transform_matrix[2 * cols + 2] = T(1.0); 207 208 transform_matrix[3 * cols + 1] = T(1.0); 209 transform_matrix[3 * cols + 3] = T(-1.0); 210 211 // Sub matrix [0,2] 212 transform_matrix[0 * cols + 8] = T(-1.0); 213 transform_matrix[0 * cols + 10] = T(1.0); 214 215 transform_matrix[1 * cols + 9] = T(-1.0); 216 transform_matrix[1 * cols + 10] = T(-1.0); 217 218 transform_matrix[2 * cols + 9] = T(1.0); 219 transform_matrix[2 * cols + 10] = T(-1.0); 220 221 transform_matrix[3 * cols + 9] = T(-1.0); 222 transform_matrix[3 * cols + 11] = T(1.0); 223 224 // Sub matrix [1,1] 225 transform_matrix[4 * cols + 4] = T(1.0); 226 transform_matrix[4 * cols + 6] = T(-1.0); 227 228 transform_matrix[5 * cols + 5] = T(1.0); 229 transform_matrix[5 * cols + 6] = T(1.0); 230 231 transform_matrix[6 * cols + 5] = T(-1.0); 232 transform_matrix[6 * cols + 6] = T(1.0); 233 234 transform_matrix[7 * cols + 5] = T(1.0); 235 transform_matrix[7 * cols + 7] = T(-1.0); 236 237 // Sub matrix [1,2] 238 transform_matrix[4 * cols + 8] = T(1.0); 239 transform_matrix[4 * cols + 10] = T(-1.0); 240 241 transform_matrix[5 * cols + 9] = T(1.0); 242 transform_matrix[5 * cols + 10] = T(1.0); 243 244 transform_matrix[6 * cols + 9] = T(-1.0); 245 transform_matrix[6 * cols + 10] = T(1.0); 246 247 transform_matrix[7 * cols + 9] = T(1.0); 248 transform_matrix[7 * cols + 11] = T(-1.0); 249 250 // Sub matrix [2,1] 251 transform_matrix[8 * cols + 4] = T(-1.0); 252 transform_matrix[8 * cols + 6] = T(1.0); 253 254 transform_matrix[9 * cols + 5] = T(-1.0); 255 transform_matrix[9 * cols + 6] = T(-1.0); 256 257 transform_matrix[10 * cols + 5] = T(1.0); 258 transform_matrix[10 * cols + 6] = T(-1.0); 259 260 transform_matrix[11 * cols + 5] = T(-1.0); 261 transform_matrix[11 * cols + 7] = T(1.0); 262 263 // Sub matrix [2,2] 264 transform_matrix[8 * cols + 8] = T(1.0); 265 transform_matrix[8 * cols + 10] = T(-1.0); 266 267 transform_matrix[9 * cols + 9] = T(1.0); 268 transform_matrix[9 * cols + 10] = T(1.0); 269 270 transform_matrix[10 * cols + 9] = T(-1.0); 271 transform_matrix[10 * cols + 10] = T(1.0); 272 273 transform_matrix[11 * cols + 9] = T(1.0); 274 transform_matrix[11 * cols + 11] = T(-1.0); 275 276 // Sub matrix [3,1] 277 transform_matrix[12 * cols + 4] = T(1.0); 278 transform_matrix[12 * cols + 6] = T(-1.0); 279 280 transform_matrix[13 * cols + 5] = T(1.0); 281 transform_matrix[13 * cols + 6] = T(1.0); 282 283 transform_matrix[14 * cols + 5] = T(-1.0); 284 transform_matrix[14 * cols + 6] = T(1.0); 285 286 transform_matrix[15 * cols + 5] = T(1.0); 287 transform_matrix[15 * cols + 7] = T(-1.0); 288 289 // Sub matrix [3,3] 290 transform_matrix[12 * cols + 12] = T(-1.0); 291 transform_matrix[12 * cols + 14] = T(1.0); 292 293 transform_matrix[13 * cols + 13] = T(-1.0); 294 transform_matrix[13 * cols + 14] = T(-1.0); 295 296 transform_matrix[14 * cols + 13] = T(1.0); 297 transform_matrix[14 * cols + 14] = T(-1.0); 298 299 transform_matrix[15 * cols + 13] = T(-1.0); 300 transform_matrix[15 * cols + 15] = T(1.0); 301 }; 302 303 // The output transform matrix is the kronecker product 'M * M' of the 304 // following matrix 'M': 305 // 306 // [1 1 1 0] 307 // [0 1 -1 -1] 308 // 309 // Data layout of 'transform_matrix': 310 // [out_tile_spatial_size, tile_spatial_size] 311 // 312 template <typename T> 313 void WinogradTransform<T>::GetOutputTransformMatrix(const int64 rows, 314 const int64 cols, 315 T* transform_matrix) const { 316 CHECK_GT(rows, 0); 317 CHECK_GT(cols, 0); 318 memset(transform_matrix, 0, sizeof(T) * rows * cols); 319 320 // Sub matrix [0,0] 321 transform_matrix[0 * cols + 0] = T(1.0); 322 transform_matrix[0 * cols + 1] = T(1.0); 323 transform_matrix[0 * cols + 2] = T(1.0); 324 325 transform_matrix[1 * cols + 1] = T(1.0); 326 transform_matrix[1 * cols + 2] = T(-1.0); 327 transform_matrix[1 * cols + 3] = T(-1.0); 328 329 // Sub matrix [0,1] 330 transform_matrix[0 * cols + 4] = T(1.0); 331 transform_matrix[0 * cols + 5] = T(1.0); 332 transform_matrix[0 * cols + 6] = T(1.0); 333 334 transform_matrix[1 * cols + 5] = T(1.0); 335 transform_matrix[1 * cols + 6] = T(-1.0); 336 transform_matrix[1 * cols + 7] = T(-1.0); 337 338 // Sub matrix [0,2] 339 transform_matrix[0 * cols + 8] = T(1.0); 340 transform_matrix[0 * cols + 9] = T(1.0); 341 transform_matrix[0 * cols + 10] = T(1.0); 342 343 transform_matrix[1 * cols + 9] = T(1.0); 344 transform_matrix[1 * cols + 10] = T(-1.0); 345 transform_matrix[1 * cols + 11] = T(-1.0); 346 347 // Sub matrix [1,1] 348 transform_matrix[2 * cols + 4] = T(1.0); 349 transform_matrix[2 * cols + 5] = T(1.0); 350 transform_matrix[2 * cols + 6] = T(1.0); 351 352 transform_matrix[3 * cols + 5] = T(1.0); 353 transform_matrix[3 * cols + 6] = T(-1.0); 354 transform_matrix[3 * cols + 7] = T(-1.0); 355 356 // Sub matrix [1,2] 357 transform_matrix[2 * cols + 8] = T(-1.0); 358 transform_matrix[2 * cols + 9] = T(-1.0); 359 transform_matrix[2 * cols + 10] = T(-1.0); 360 361 transform_matrix[3 * cols + 9] = T(-1.0); 362 transform_matrix[3 * cols + 10] = T(1.0); 363 transform_matrix[3 * cols + 11] = T(1.0); 364 365 // Sub matrix [1,3] 366 transform_matrix[2 * cols + 12] = T(-1.0); 367 transform_matrix[2 * cols + 13] = T(-1.0); 368 transform_matrix[2 * cols + 14] = T(-1.0); 369 370 transform_matrix[3 * cols + 13] = T(-1.0); 371 transform_matrix[3 * cols + 14] = T(1.0); 372 transform_matrix[3 * cols + 15] = T(1.0); 373 }; 374 375 } // namespace tensorflow 376 377 #endif // TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_ 378