Home | History | Annotate | Download | only in test
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog (at) gmail.com>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #include "main.h"
     11 
     12 #include <Eigen/CXX11/Tensor>
     13 
     14 using Eigen::Tensor;
     15 
     16 template<int DataLayout>
     17 static void test_simple_patch()
     18 {
     19   Tensor<float, 4, DataLayout> tensor(2,3,5,7);
     20   tensor.setRandom();
     21   array<ptrdiff_t, 4> patch_dims;
     22 
     23   patch_dims[0] = 1;
     24   patch_dims[1] = 1;
     25   patch_dims[2] = 1;
     26   patch_dims[3] = 1;
     27 
     28   Tensor<float, 5, DataLayout> no_patch;
     29   no_patch = tensor.extract_patches(patch_dims);
     30 
     31   if (DataLayout == ColMajor) {
     32     VERIFY_IS_EQUAL(no_patch.dimension(0), 1);
     33     VERIFY_IS_EQUAL(no_patch.dimension(1), 1);
     34     VERIFY_IS_EQUAL(no_patch.dimension(2), 1);
     35     VERIFY_IS_EQUAL(no_patch.dimension(3), 1);
     36     VERIFY_IS_EQUAL(no_patch.dimension(4), tensor.size());
     37   } else {
     38     VERIFY_IS_EQUAL(no_patch.dimension(0), tensor.size());
     39     VERIFY_IS_EQUAL(no_patch.dimension(1), 1);
     40     VERIFY_IS_EQUAL(no_patch.dimension(2), 1);
     41     VERIFY_IS_EQUAL(no_patch.dimension(3), 1);
     42     VERIFY_IS_EQUAL(no_patch.dimension(4), 1);
     43   }
     44 
     45   for (int i = 0; i < tensor.size(); ++i) {
     46     VERIFY_IS_EQUAL(tensor.data()[i], no_patch.data()[i]);
     47   }
     48 
     49   patch_dims[0] = 2;
     50   patch_dims[1] = 3;
     51   patch_dims[2] = 5;
     52   patch_dims[3] = 7;
     53   Tensor<float, 5, DataLayout> single_patch;
     54   single_patch = tensor.extract_patches(patch_dims);
     55 
     56   if (DataLayout == ColMajor) {
     57     VERIFY_IS_EQUAL(single_patch.dimension(0), 2);
     58     VERIFY_IS_EQUAL(single_patch.dimension(1), 3);
     59     VERIFY_IS_EQUAL(single_patch.dimension(2), 5);
     60     VERIFY_IS_EQUAL(single_patch.dimension(3), 7);
     61     VERIFY_IS_EQUAL(single_patch.dimension(4), 1);
     62   } else {
     63     VERIFY_IS_EQUAL(single_patch.dimension(0), 1);
     64     VERIFY_IS_EQUAL(single_patch.dimension(1), 2);
     65     VERIFY_IS_EQUAL(single_patch.dimension(2), 3);
     66     VERIFY_IS_EQUAL(single_patch.dimension(3), 5);
     67     VERIFY_IS_EQUAL(single_patch.dimension(4), 7);
     68   }
     69 
     70   for (int i = 0; i < tensor.size(); ++i) {
     71     VERIFY_IS_EQUAL(tensor.data()[i], single_patch.data()[i]);
     72   }
     73 
     74   patch_dims[0] = 1;
     75   patch_dims[1] = 2;
     76   patch_dims[2] = 2;
     77   patch_dims[3] = 1;
     78   Tensor<float, 5, DataLayout> twod_patch;
     79   twod_patch = tensor.extract_patches(patch_dims);
     80 
     81   if (DataLayout == ColMajor) {
     82     VERIFY_IS_EQUAL(twod_patch.dimension(0), 1);
     83     VERIFY_IS_EQUAL(twod_patch.dimension(1), 2);
     84     VERIFY_IS_EQUAL(twod_patch.dimension(2), 2);
     85     VERIFY_IS_EQUAL(twod_patch.dimension(3), 1);
     86     VERIFY_IS_EQUAL(twod_patch.dimension(4), 2*2*4*7);
     87   } else {
     88     VERIFY_IS_EQUAL(twod_patch.dimension(0), 2*2*4*7);
     89     VERIFY_IS_EQUAL(twod_patch.dimension(1), 1);
     90     VERIFY_IS_EQUAL(twod_patch.dimension(2), 2);
     91     VERIFY_IS_EQUAL(twod_patch.dimension(3), 2);
     92     VERIFY_IS_EQUAL(twod_patch.dimension(4), 1);
     93   }
     94 
     95   for (int i = 0; i < 2; ++i) {
     96     for (int j = 0; j < 2; ++j) {
     97       for (int k = 0; k < 4; ++k) {
     98         for (int l = 0; l < 7; ++l) {
     99           int patch_loc;
    100           if (DataLayout == ColMajor) {
    101             patch_loc = i + 2 * (j + 2 * (k + 4 * l));
    102           } else {
    103             patch_loc = l + 7 * (k + 4 * (j + 2 * i));
    104           }
    105           for (int x = 0; x < 2; ++x) {
    106             for (int y = 0; y < 2; ++y) {
    107               if (DataLayout == ColMajor) {
    108                 VERIFY_IS_EQUAL(tensor(i,j+x,k+y,l), twod_patch(0,x,y,0,patch_loc));
    109               } else {
    110                 VERIFY_IS_EQUAL(tensor(i,j+x,k+y,l), twod_patch(patch_loc,0,x,y,0));
    111               }
    112             }
    113           }
    114         }
    115       }
    116     }
    117   }
    118 
    119   patch_dims[0] = 1;
    120   patch_dims[1] = 2;
    121   patch_dims[2] = 3;
    122   patch_dims[3] = 5;
    123   Tensor<float, 5, DataLayout> threed_patch;
    124   threed_patch = tensor.extract_patches(patch_dims);
    125 
    126   if (DataLayout == ColMajor) {
    127     VERIFY_IS_EQUAL(threed_patch.dimension(0), 1);
    128     VERIFY_IS_EQUAL(threed_patch.dimension(1), 2);
    129     VERIFY_IS_EQUAL(threed_patch.dimension(2), 3);
    130     VERIFY_IS_EQUAL(threed_patch.dimension(3), 5);
    131     VERIFY_IS_EQUAL(threed_patch.dimension(4), 2*2*3*3);
    132   } else {
    133     VERIFY_IS_EQUAL(threed_patch.dimension(0), 2*2*3*3);
    134     VERIFY_IS_EQUAL(threed_patch.dimension(1), 1);
    135     VERIFY_IS_EQUAL(threed_patch.dimension(2), 2);
    136     VERIFY_IS_EQUAL(threed_patch.dimension(3), 3);
    137     VERIFY_IS_EQUAL(threed_patch.dimension(4), 5);
    138   }
    139 
    140   for (int i = 0; i < 2; ++i) {
    141     for (int j = 0; j < 2; ++j) {
    142       for (int k = 0; k < 3; ++k) {
    143         for (int l = 0; l < 3; ++l) {
    144           int patch_loc;
    145           if (DataLayout == ColMajor) {
    146             patch_loc = i + 2 * (j + 2 * (k + 3 * l));
    147           } else {
    148             patch_loc = l + 3 * (k + 3 * (j + 2 * i));
    149           }
    150           for (int x = 0; x < 2; ++x) {
    151             for (int y = 0; y < 3; ++y) {
    152               for (int z = 0; z < 5; ++z) {
    153                 if (DataLayout == ColMajor) {
    154                   VERIFY_IS_EQUAL(tensor(i,j+x,k+y,l+z), threed_patch(0,x,y,z,patch_loc));
    155                 } else {
    156                   VERIFY_IS_EQUAL(tensor(i,j+x,k+y,l+z), threed_patch(patch_loc,0,x,y,z));
    157                 }
    158               }
    159             }
    160           }
    161         }
    162       }
    163     }
    164   }
    165 }
    166 
    167 void test_cxx11_tensor_patch()
    168 {
    169    CALL_SUBTEST(test_simple_patch<ColMajor>());
    170    CALL_SUBTEST(test_simple_patch<RowMajor>());
    171    //   CALL_SUBTEST(test_expr_shuffling());
    172 }
    173