Home | History | Annotate | Download | only in test
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2015 Ke Yang <yangke (at) gmail.com>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #include "main.h"
     11 
     12 #include <Eigen/CXX11/Tensor>
     13 
     14 using Eigen::Tensor;
     15 
     16 template<int DataLayout>
     17 static void test_simple_inflation()
     18 {
     19   Tensor<float, 4, DataLayout> tensor(2,3,5,7);
     20   tensor.setRandom();
     21   array<ptrdiff_t, 4> strides;
     22 
     23   strides[0] = 1;
     24   strides[1] = 1;
     25   strides[2] = 1;
     26   strides[3] = 1;
     27 
     28   Tensor<float, 4, DataLayout> no_stride;
     29   no_stride = tensor.inflate(strides);
     30 
     31   VERIFY_IS_EQUAL(no_stride.dimension(0), 2);
     32   VERIFY_IS_EQUAL(no_stride.dimension(1), 3);
     33   VERIFY_IS_EQUAL(no_stride.dimension(2), 5);
     34   VERIFY_IS_EQUAL(no_stride.dimension(3), 7);
     35 
     36   for (int i = 0; i < 2; ++i) {
     37     for (int j = 0; j < 3; ++j) {
     38       for (int k = 0; k < 5; ++k) {
     39         for (int l = 0; l < 7; ++l) {
     40           VERIFY_IS_EQUAL(tensor(i,j,k,l), no_stride(i,j,k,l));
     41         }
     42       }
     43     }
     44   }
     45 
     46   strides[0] = 2;
     47   strides[1] = 4;
     48   strides[2] = 2;
     49   strides[3] = 3;
     50   Tensor<float, 4, DataLayout> inflated;
     51   inflated = tensor.inflate(strides);
     52 
     53   VERIFY_IS_EQUAL(inflated.dimension(0), 3);
     54   VERIFY_IS_EQUAL(inflated.dimension(1), 9);
     55   VERIFY_IS_EQUAL(inflated.dimension(2), 9);
     56   VERIFY_IS_EQUAL(inflated.dimension(3), 19);
     57 
     58   for (int i = 0; i < 3; ++i) {
     59     for (int j = 0; j < 9; ++j) {
     60       for (int k = 0; k < 9; ++k) {
     61         for (int l = 0; l < 19; ++l) {
     62           if (i % 2 == 0 &&
     63               j % 4 == 0 &&
     64               k % 2 == 0 &&
     65               l % 3 == 0) {
     66             VERIFY_IS_EQUAL(inflated(i,j,k,l),
     67                             tensor(i/2, j/4, k/2, l/3));
     68           } else {
     69             VERIFY_IS_EQUAL(0, inflated(i,j,k,l));
     70           }
     71         }
     72       }
     73     }
     74   }
     75 }
     76 
// Test entry point for the tensor inflation suite: runs
// test_simple_inflation once per storage order so that inflate() is checked
// for both column-major and row-major tensors.
void test_cxx11_tensor_inflation()
{
  CALL_SUBTEST(test_simple_inflation<ColMajor>());
  CALL_SUBTEST(test_simple_inflation<RowMajor>());
}
     82