Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_
     17 #define TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_
     18 
     19 #include "tensorflow/core/kernels/deep_conv2d.h"
     20 
     21 namespace tensorflow {
     22 
     23 // Winograd DeepConv2DTransform implementation for 3x3 filters.
     24 // Details:
     25 // *) Arithmetic complexity of computations: Shmuel Winograd
     26 // *) Fast Algorithms for Convolutional Neural Networks: Lavin, Gray
     27 
     28 template <typename T>
     29 class WinogradTransform : public DeepConv2DTransform<T> {
     30  public:
     31   typedef typename DeepConv2DTransform<T>::Shape Shape;
     32 
     33   WinogradTransform()
     34       : filter_shape_(3, 3), input_shape_(4, 4), output_shape_(2, 2) {}
     35 
     36   virtual void GetFilterTransformMatrix(const int64 rows, const int64 cols,
     37                                         T* transform_matrix) const;
     38 
     39   virtual void GetInputTransformMatrix(const int64 rows, const int64 cols,
     40                                        T* transform_matrix) const;
     41 
     42   virtual void GetOutputTransformMatrix(const int64 rows, const int64 cols,
     43                                         T* transform_matrix) const;
     44 
     45   virtual const Shape& filter_shape() const { return filter_shape_; }
     46   virtual const Shape& input_shape() const { return input_shape_; }
     47   virtual const Shape& output_shape() const { return output_shape_; }
     48 
     49  private:
     50   const Shape filter_shape_;
     51   const Shape input_shape_;
     52   const Shape output_shape_;
     53 };
     54 
     55 // The filter transform matrix is the kronecker product 'M * M' of the
     56 // following matrix 'M':
     57 //
     58 //   [ 1    0   0   ]
     59 //   [ 1/2  1/2 1/2 ]
     60 //   [ 1/2 -1/2 1/2 ]
     61 //   [ 0    0   1   ]
     62 //
     63 // The data layout of 'transform_matrix':
     64 //   [input_tile_spatial_size, filter_spatial_size]
     65 //
     66 template <typename T>
     67 void WinogradTransform<T>::GetFilterTransformMatrix(const int64 rows,
     68                                                     const int64 cols,
     69                                                     T* transform_matrix) const {
     70   CHECK_GT(rows, 0);
     71   CHECK_GT(cols, 0);
     72   memset(transform_matrix, 0, sizeof(T) * rows * cols);
     73 
     74   // Sub matrix [0,0]
     75   transform_matrix[0 * cols + 0] = T(1.0);
     76 
     77   transform_matrix[1 * cols + 0] = T(0.5);
     78   transform_matrix[1 * cols + 1] = T(0.5);
     79   transform_matrix[1 * cols + 2] = T(0.5);
     80 
     81   transform_matrix[2 * cols + 0] = T(0.5);
     82   transform_matrix[2 * cols + 1] = T(-0.5);
     83   transform_matrix[2 * cols + 2] = T(0.5);
     84 
     85   transform_matrix[3 * cols + 2] = T(1.0);
     86 
     87   // Sub matrix [1,0]
     88   transform_matrix[4 * cols + 0] = T(0.5);
     89 
     90   transform_matrix[5 * cols + 0] = T(0.25);
     91   transform_matrix[5 * cols + 1] = T(0.25);
     92   transform_matrix[5 * cols + 2] = T(0.25);
     93 
     94   transform_matrix[6 * cols + 0] = T(0.25);
     95   transform_matrix[6 * cols + 1] = T(-0.25);
     96   transform_matrix[6 * cols + 2] = T(0.25);
     97 
     98   transform_matrix[7 * cols + 2] = T(0.5);
     99 
    100   // Sub matrix [1,1]
    101   transform_matrix[4 * cols + 3] = T(0.5);
    102 
    103   transform_matrix[5 * cols + 3] = T(0.25);
    104   transform_matrix[5 * cols + 4] = T(0.25);
    105   transform_matrix[5 * cols + 5] = T(0.25);
    106 
    107   transform_matrix[6 * cols + 3] = T(0.25);
    108   transform_matrix[6 * cols + 4] = T(-0.25);
    109   transform_matrix[6 * cols + 5] = T(0.25);
    110 
    111   transform_matrix[7 * cols + 5] = T(0.5);
    112 
    113   // Sub matrix [1,2]
    114   transform_matrix[4 * cols + 6] = T(0.5);
    115 
    116   transform_matrix[5 * cols + 6] = T(0.25);
    117   transform_matrix[5 * cols + 7] = T(0.25);
    118   transform_matrix[5 * cols + 8] = T(0.25);
    119 
    120   transform_matrix[6 * cols + 6] = T(0.25);
    121   transform_matrix[6 * cols + 7] = T(-0.25);
    122   transform_matrix[6 * cols + 8] = T(0.25);
    123 
    124   transform_matrix[7 * cols + 8] = T(0.5);
    125 
    126   // Sub matrix [2,0]
    127   transform_matrix[8 * cols + 0] = T(0.5);
    128 
    129   transform_matrix[9 * cols + 0] = T(0.25);
    130   transform_matrix[9 * cols + 1] = T(0.25);
    131   transform_matrix[9 * cols + 2] = T(0.25);
    132 
    133   transform_matrix[10 * cols + 0] = T(0.25);
    134   transform_matrix[10 * cols + 1] = T(-0.25);
    135   transform_matrix[10 * cols + 2] = T(0.25);
    136 
    137   transform_matrix[11 * cols + 2] = T(0.5);
    138 
    139   // Sub matrix [2,1]
    140   transform_matrix[8 * cols + 3] = T(-0.5);
    141 
    142   transform_matrix[9 * cols + 3] = T(-0.25);
    143   transform_matrix[9 * cols + 4] = T(-0.25);
    144   transform_matrix[9 * cols + 5] = T(-0.25);
    145 
    146   transform_matrix[10 * cols + 3] = T(-0.25);
    147   transform_matrix[10 * cols + 4] = T(0.25);
    148   transform_matrix[10 * cols + 5] = T(-0.25);
    149 
    150   transform_matrix[11 * cols + 5] = T(-0.5);
    151 
    152   // Sub matrix [2,2]
    153   transform_matrix[8 * cols + 6] = T(0.5);
    154 
    155   transform_matrix[9 * cols + 6] = T(0.25);
    156   transform_matrix[9 * cols + 7] = T(0.25);
    157   transform_matrix[9 * cols + 8] = T(0.25);
    158 
    159   transform_matrix[10 * cols + 6] = T(0.25);
    160   transform_matrix[10 * cols + 7] = T(-0.25);
    161   transform_matrix[10 * cols + 8] = T(0.25);
    162 
    163   transform_matrix[11 * cols + 8] = T(0.5);
    164 
    165   // Sub matrix [3,2]
    166   transform_matrix[12 * cols + 6] = T(1.0);
    167 
    168   transform_matrix[13 * cols + 6] = T(0.5);
    169   transform_matrix[13 * cols + 7] = T(0.5);
    170   transform_matrix[13 * cols + 8] = T(0.5);
    171 
    172   transform_matrix[14 * cols + 6] = T(0.5);
    173   transform_matrix[14 * cols + 7] = T(-0.5);
    174   transform_matrix[14 * cols + 8] = T(0.5);
    175 
    176   transform_matrix[15 * cols + 8] = T(1.0);
    177 }
    178 
    179 // The input transform matrix is the kronecker product 'M * M' of the
    180 // following matrix 'M':
    181 //
    182 //   [1   0  -1   0]
    183 //   [0   1   1   0]
    184 //   [0  -1   1   0]
    185 //   [0   1   0  -1]
    186 //
    187 // Data layout of 'transform_matrix':
    188 //   [tile_spatial_size, tile_spatial_size]
    189 //
    190 template <typename T>
    191 void WinogradTransform<T>::GetInputTransformMatrix(const int64 rows,
    192                                                    const int64 cols,
    193                                                    T* transform_matrix) const {
    194   CHECK_GT(rows, 0);
    195   CHECK_GT(cols, 0);
    196   memset(transform_matrix, 0, sizeof(T) * rows * cols);
    197 
    198   // Sub matrix [0,0]
    199   transform_matrix[0 * cols + 0] = T(1.0);
    200   transform_matrix[0 * cols + 2] = T(-1.0);
    201 
    202   transform_matrix[1 * cols + 1] = T(1.0);
    203   transform_matrix[1 * cols + 2] = T(1.0);
    204 
    205   transform_matrix[2 * cols + 1] = T(-1.0);
    206   transform_matrix[2 * cols + 2] = T(1.0);
    207 
    208   transform_matrix[3 * cols + 1] = T(1.0);
    209   transform_matrix[3 * cols + 3] = T(-1.0);
    210 
    211   // Sub matrix [0,2]
    212   transform_matrix[0 * cols + 8] = T(-1.0);
    213   transform_matrix[0 * cols + 10] = T(1.0);
    214 
    215   transform_matrix[1 * cols + 9] = T(-1.0);
    216   transform_matrix[1 * cols + 10] = T(-1.0);
    217 
    218   transform_matrix[2 * cols + 9] = T(1.0);
    219   transform_matrix[2 * cols + 10] = T(-1.0);
    220 
    221   transform_matrix[3 * cols + 9] = T(-1.0);
    222   transform_matrix[3 * cols + 11] = T(1.0);
    223 
    224   // Sub matrix [1,1]
    225   transform_matrix[4 * cols + 4] = T(1.0);
    226   transform_matrix[4 * cols + 6] = T(-1.0);
    227 
    228   transform_matrix[5 * cols + 5] = T(1.0);
    229   transform_matrix[5 * cols + 6] = T(1.0);
    230 
    231   transform_matrix[6 * cols + 5] = T(-1.0);
    232   transform_matrix[6 * cols + 6] = T(1.0);
    233 
    234   transform_matrix[7 * cols + 5] = T(1.0);
    235   transform_matrix[7 * cols + 7] = T(-1.0);
    236 
    237   // Sub matrix [1,2]
    238   transform_matrix[4 * cols + 8] = T(1.0);
    239   transform_matrix[4 * cols + 10] = T(-1.0);
    240 
    241   transform_matrix[5 * cols + 9] = T(1.0);
    242   transform_matrix[5 * cols + 10] = T(1.0);
    243 
    244   transform_matrix[6 * cols + 9] = T(-1.0);
    245   transform_matrix[6 * cols + 10] = T(1.0);
    246 
    247   transform_matrix[7 * cols + 9] = T(1.0);
    248   transform_matrix[7 * cols + 11] = T(-1.0);
    249 
    250   // Sub matrix [2,1]
    251   transform_matrix[8 * cols + 4] = T(-1.0);
    252   transform_matrix[8 * cols + 6] = T(1.0);
    253 
    254   transform_matrix[9 * cols + 5] = T(-1.0);
    255   transform_matrix[9 * cols + 6] = T(-1.0);
    256 
    257   transform_matrix[10 * cols + 5] = T(1.0);
    258   transform_matrix[10 * cols + 6] = T(-1.0);
    259 
    260   transform_matrix[11 * cols + 5] = T(-1.0);
    261   transform_matrix[11 * cols + 7] = T(1.0);
    262 
    263   // Sub matrix [2,2]
    264   transform_matrix[8 * cols + 8] = T(1.0);
    265   transform_matrix[8 * cols + 10] = T(-1.0);
    266 
    267   transform_matrix[9 * cols + 9] = T(1.0);
    268   transform_matrix[9 * cols + 10] = T(1.0);
    269 
    270   transform_matrix[10 * cols + 9] = T(-1.0);
    271   transform_matrix[10 * cols + 10] = T(1.0);
    272 
    273   transform_matrix[11 * cols + 9] = T(1.0);
    274   transform_matrix[11 * cols + 11] = T(-1.0);
    275 
    276   // Sub matrix [3,1]
    277   transform_matrix[12 * cols + 4] = T(1.0);
    278   transform_matrix[12 * cols + 6] = T(-1.0);
    279 
    280   transform_matrix[13 * cols + 5] = T(1.0);
    281   transform_matrix[13 * cols + 6] = T(1.0);
    282 
    283   transform_matrix[14 * cols + 5] = T(-1.0);
    284   transform_matrix[14 * cols + 6] = T(1.0);
    285 
    286   transform_matrix[15 * cols + 5] = T(1.0);
    287   transform_matrix[15 * cols + 7] = T(-1.0);
    288 
    289   // Sub matrix [3,3]
    290   transform_matrix[12 * cols + 12] = T(-1.0);
    291   transform_matrix[12 * cols + 14] = T(1.0);
    292 
    293   transform_matrix[13 * cols + 13] = T(-1.0);
    294   transform_matrix[13 * cols + 14] = T(-1.0);
    295 
    296   transform_matrix[14 * cols + 13] = T(1.0);
    297   transform_matrix[14 * cols + 14] = T(-1.0);
    298 
    299   transform_matrix[15 * cols + 13] = T(-1.0);
    300   transform_matrix[15 * cols + 15] = T(1.0);
    301 };
    302 
    303 // The output transform matrix is the kronecker product 'M * M' of the
    304 // following matrix 'M':
    305 //
    306 //   [1  1  1  0]
    307 //   [0  1 -1 -1]
    308 //
    309 // Data layout of 'transform_matrix':
    310 //   [out_tile_spatial_size, tile_spatial_size]
    311 //
    312 template <typename T>
    313 void WinogradTransform<T>::GetOutputTransformMatrix(const int64 rows,
    314                                                     const int64 cols,
    315                                                     T* transform_matrix) const {
    316   CHECK_GT(rows, 0);
    317   CHECK_GT(cols, 0);
    318   memset(transform_matrix, 0, sizeof(T) * rows * cols);
    319 
    320   // Sub matrix [0,0]
    321   transform_matrix[0 * cols + 0] = T(1.0);
    322   transform_matrix[0 * cols + 1] = T(1.0);
    323   transform_matrix[0 * cols + 2] = T(1.0);
    324 
    325   transform_matrix[1 * cols + 1] = T(1.0);
    326   transform_matrix[1 * cols + 2] = T(-1.0);
    327   transform_matrix[1 * cols + 3] = T(-1.0);
    328 
    329   // Sub matrix [0,1]
    330   transform_matrix[0 * cols + 4] = T(1.0);
    331   transform_matrix[0 * cols + 5] = T(1.0);
    332   transform_matrix[0 * cols + 6] = T(1.0);
    333 
    334   transform_matrix[1 * cols + 5] = T(1.0);
    335   transform_matrix[1 * cols + 6] = T(-1.0);
    336   transform_matrix[1 * cols + 7] = T(-1.0);
    337 
    338   // Sub matrix [0,2]
    339   transform_matrix[0 * cols + 8] = T(1.0);
    340   transform_matrix[0 * cols + 9] = T(1.0);
    341   transform_matrix[0 * cols + 10] = T(1.0);
    342 
    343   transform_matrix[1 * cols + 9] = T(1.0);
    344   transform_matrix[1 * cols + 10] = T(-1.0);
    345   transform_matrix[1 * cols + 11] = T(-1.0);
    346 
    347   // Sub matrix [1,1]
    348   transform_matrix[2 * cols + 4] = T(1.0);
    349   transform_matrix[2 * cols + 5] = T(1.0);
    350   transform_matrix[2 * cols + 6] = T(1.0);
    351 
    352   transform_matrix[3 * cols + 5] = T(1.0);
    353   transform_matrix[3 * cols + 6] = T(-1.0);
    354   transform_matrix[3 * cols + 7] = T(-1.0);
    355 
    356   // Sub matrix [1,2]
    357   transform_matrix[2 * cols + 8] = T(-1.0);
    358   transform_matrix[2 * cols + 9] = T(-1.0);
    359   transform_matrix[2 * cols + 10] = T(-1.0);
    360 
    361   transform_matrix[3 * cols + 9] = T(-1.0);
    362   transform_matrix[3 * cols + 10] = T(1.0);
    363   transform_matrix[3 * cols + 11] = T(1.0);
    364 
    365   // Sub matrix [1,3]
    366   transform_matrix[2 * cols + 12] = T(-1.0);
    367   transform_matrix[2 * cols + 13] = T(-1.0);
    368   transform_matrix[2 * cols + 14] = T(-1.0);
    369 
    370   transform_matrix[3 * cols + 13] = T(-1.0);
    371   transform_matrix[3 * cols + 14] = T(1.0);
    372   transform_matrix[3 * cols + 15] = T(1.0);
    373 };
    374 
    375 }  // namespace tensorflow
    376 
    377 #endif  // TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_
    378