Home | History | Annotate | Download | only in test
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog (at) gmail.com>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #include "main.h"
     11 
     12 #include <Eigen/CXX11/Tensor>
     13 
     14 using Eigen::Tensor;
     15 using Eigen::RowMajor;
     16 
     17 static void test_simple_lvalue_ref()
     18 {
     19   Tensor<int, 1> input(6);
     20   input.setRandom();
     21 
     22   TensorRef<Tensor<int, 1>> ref3(input);
     23   TensorRef<Tensor<int, 1>> ref4 = input;
     24 
     25   VERIFY_IS_EQUAL(ref3.data(), input.data());
     26   VERIFY_IS_EQUAL(ref4.data(), input.data());
     27 
     28   for (int i = 0; i < 6; ++i) {
     29     VERIFY_IS_EQUAL(ref3(i), input(i));
     30     VERIFY_IS_EQUAL(ref4(i), input(i));
     31   }
     32 
     33   for (int i = 0; i < 6; ++i) {
     34     ref3.coeffRef(i) = i;
     35   }
     36   for (int i = 0; i < 6; ++i) {
     37     VERIFY_IS_EQUAL(input(i), i);
     38   }
     39   for (int i = 0; i < 6; ++i) {
     40     ref4.coeffRef(i) = -i * 2;
     41   }
     42   for (int i = 0; i < 6; ++i) {
     43     VERIFY_IS_EQUAL(input(i), -i*2);
     44   }
     45 }
     46 
     47 
     48 static void test_simple_rvalue_ref()
     49 {
     50   Tensor<int, 1> input1(6);
     51   input1.setRandom();
     52   Tensor<int, 1> input2(6);
     53   input2.setRandom();
     54 
     55   TensorRef<Tensor<int, 1>> ref3(input1 + input2);
     56   TensorRef<Tensor<int, 1>> ref4 = input1 + input2;
     57 
     58   VERIFY_IS_NOT_EQUAL(ref3.data(), input1.data());
     59   VERIFY_IS_NOT_EQUAL(ref4.data(), input1.data());
     60   VERIFY_IS_NOT_EQUAL(ref3.data(), input2.data());
     61   VERIFY_IS_NOT_EQUAL(ref4.data(), input2.data());
     62 
     63   for (int i = 0; i < 6; ++i) {
     64     VERIFY_IS_EQUAL(ref3(i), input1(i) + input2(i));
     65     VERIFY_IS_EQUAL(ref4(i), input1(i) + input2(i));
     66   }
     67 }
     68 
     69 
     70 static void test_multiple_dims()
     71 {
     72   Tensor<float, 3> input(3,5,7);
     73   input.setRandom();
     74 
     75   TensorRef<Tensor<float, 3>> ref(input);
     76   VERIFY_IS_EQUAL(ref.data(), input.data());
     77   VERIFY_IS_EQUAL(ref.dimension(0), 3);
     78   VERIFY_IS_EQUAL(ref.dimension(1), 5);
     79   VERIFY_IS_EQUAL(ref.dimension(2), 7);
     80 
     81   for (int i = 0; i < 3; ++i) {
     82     for (int j = 0; j < 5; ++j) {
     83       for (int k = 0; k < 7; ++k) {
     84         VERIFY_IS_EQUAL(ref(i,j,k), input(i,j,k));
     85       }
     86     }
     87   }
     88 }
     89 
     90 
     91 static void test_slice()
     92 {
     93   Tensor<float, 5> tensor(2,3,5,7,11);
     94   tensor.setRandom();
     95 
     96   Eigen::DSizes<ptrdiff_t, 5> indices(1,2,3,4,5);
     97   Eigen::DSizes<ptrdiff_t, 5> sizes(1,1,1,1,1);
     98   TensorRef<Tensor<float, 5>> slice = tensor.slice(indices, sizes);
     99   VERIFY_IS_EQUAL(slice(0,0,0,0,0), tensor(1,2,3,4,5));
    100 
    101   Eigen::DSizes<ptrdiff_t, 5> indices2(1,1,3,4,5);
    102   Eigen::DSizes<ptrdiff_t, 5> sizes2(1,1,2,2,3);
    103   slice = tensor.slice(indices2, sizes2);
    104   for (int i = 0; i < 2; ++i) {
    105     for (int j = 0; j < 2; ++j) {
    106       for (int k = 0; k < 3; ++k) {
    107         VERIFY_IS_EQUAL(slice(0,0,i,j,k), tensor(1,1,3+i,4+j,5+k));
    108       }
    109     }
    110   }
    111 
    112   Eigen::DSizes<ptrdiff_t, 5> indices3(0,0,0,0,0);
    113   Eigen::DSizes<ptrdiff_t, 5> sizes3(2,3,1,1,1);
    114   slice = tensor.slice(indices3, sizes3);
    115   VERIFY_IS_EQUAL(slice.data(), tensor.data());
    116 }
    117 
    118 
    119 static void test_ref_of_ref()
    120 {
    121   Tensor<float, 3> input(3,5,7);
    122   input.setRandom();
    123 
    124   TensorRef<Tensor<float, 3>> ref(input);
    125   TensorRef<Tensor<float, 3>> ref_of_ref(ref);
    126   TensorRef<Tensor<float, 3>> ref_of_ref2;
    127   ref_of_ref2 = ref;
    128 
    129   VERIFY_IS_EQUAL(ref_of_ref.data(), input.data());
    130   VERIFY_IS_EQUAL(ref_of_ref.dimension(0), 3);
    131   VERIFY_IS_EQUAL(ref_of_ref.dimension(1), 5);
    132   VERIFY_IS_EQUAL(ref_of_ref.dimension(2), 7);
    133 
    134   VERIFY_IS_EQUAL(ref_of_ref2.data(), input.data());
    135   VERIFY_IS_EQUAL(ref_of_ref2.dimension(0), 3);
    136   VERIFY_IS_EQUAL(ref_of_ref2.dimension(1), 5);
    137   VERIFY_IS_EQUAL(ref_of_ref2.dimension(2), 7);
    138 
    139   for (int i = 0; i < 3; ++i) {
    140     for (int j = 0; j < 5; ++j) {
    141       for (int k = 0; k < 7; ++k) {
    142         VERIFY_IS_EQUAL(ref_of_ref(i,j,k), input(i,j,k));
    143         VERIFY_IS_EQUAL(ref_of_ref2(i,j,k), input(i,j,k));
    144      }
    145     }
    146   }
    147 }
    148 
    149 
    150 static void test_ref_in_expr()
    151 {
    152   Tensor<float, 3> input(3,5,7);
    153   input.setRandom();
    154   TensorRef<Tensor<float, 3>> input_ref(input);
    155 
    156   Tensor<float, 3> result(3,5,7);
    157   result.setRandom();
    158   TensorRef<Tensor<float, 3>> result_ref(result);
    159 
    160   Tensor<float, 3> bias(3,5,7);
    161   bias.setRandom();
    162 
    163   result_ref = input_ref + bias;
    164   for (int i = 0; i < 3; ++i) {
    165     for (int j = 0; j < 5; ++j) {
    166       for (int k = 0; k < 7; ++k) {
    167         VERIFY_IS_EQUAL(result_ref(i,j,k), input(i,j,k) + bias(i,j,k));
    168         VERIFY_IS_NOT_EQUAL(result(i,j,k), input(i,j,k) + bias(i,j,k));
    169       }
    170     }
    171   }
    172 
    173   result = result_ref;
    174   for (int i = 0; i < 3; ++i) {
    175     for (int j = 0; j < 5; ++j) {
    176       for (int k = 0; k < 7; ++k) {
    177         VERIFY_IS_EQUAL(result(i,j,k), input(i,j,k) + bias(i,j,k));
    178       }
    179     }
    180   }
    181 }
    182 
    183 
    184 static void test_coeff_ref()
    185 {
    186   Tensor<float, 5> tensor(2,3,5,7,11);
    187   tensor.setRandom();
    188   Tensor<float, 5> original = tensor;
    189 
    190   TensorRef<Tensor<float, 4>> slice = tensor.chip(7, 4);
    191   slice.coeffRef(0, 0, 0, 0) = 1.0f;
    192   slice.coeffRef(1, 0, 0, 0) += 2.0f;
    193 
    194   VERIFY_IS_EQUAL(tensor(0,0,0,0,7), 1.0f);
    195   VERIFY_IS_EQUAL(tensor(1,0,0,0,7), original(1,0,0,0,7) + 2.0f);
    196 }
    197 
    198 
    199 static void test_nested_ops_with_ref()
    200 {
    201   Tensor<float, 4> t(2, 3, 5, 7);
    202   t.setRandom();
    203   TensorMap<Tensor<const float, 4> > m(t.data(), 2, 3, 5, 7);
    204   array<std::pair<ptrdiff_t, ptrdiff_t>, 4> paddings;
    205   paddings[0] = std::make_pair(0, 0);
    206   paddings[1] = std::make_pair(2, 1);
    207   paddings[2] = std::make_pair(3, 4);
    208   paddings[3] = std::make_pair(0, 0);
    209   DSizes<Eigen::DenseIndex, 4> shuffle_dims(0, 1, 2, 3);
    210   TensorRef<Tensor<const float, 4> > ref(m.pad(paddings));
    211   array<std::pair<ptrdiff_t, ptrdiff_t>, 4> trivial;
    212   trivial[0] = std::make_pair(0, 0);
    213   trivial[1] = std::make_pair(0, 0);
    214   trivial[2] = std::make_pair(0, 0);
    215   trivial[3] = std::make_pair(0, 0);
    216   Tensor<float, 4> padded = ref.shuffle(shuffle_dims).pad(trivial);
    217   VERIFY_IS_EQUAL(padded.dimension(0), 2+0);
    218   VERIFY_IS_EQUAL(padded.dimension(1), 3+3);
    219   VERIFY_IS_EQUAL(padded.dimension(2), 5+7);
    220   VERIFY_IS_EQUAL(padded.dimension(3), 7+0);
    221 
    222   for (int i = 0; i < 2; ++i) {
    223     for (int j = 0; j < 6; ++j) {
    224       for (int k = 0; k < 12; ++k) {
    225         for (int l = 0; l < 7; ++l) {
    226           if (j >= 2 && j < 5 && k >= 3 && k < 8) {
    227             VERIFY_IS_EQUAL(padded(i,j,k,l), t(i,j-2,k-3,l));
    228           } else {
    229             VERIFY_IS_EQUAL(padded(i,j,k,l), 0.0f);
    230           }
    231         }
    232       }
    233     }
    234   }
    235 }
    236 
    237 
    238 void test_cxx11_tensor_ref()
    239 {
    240   CALL_SUBTEST(test_simple_lvalue_ref());
    241   CALL_SUBTEST(test_simple_rvalue_ref());
    242   CALL_SUBTEST(test_multiple_dims());
    243   CALL_SUBTEST(test_slice());
    244   CALL_SUBTEST(test_ref_of_ref());
    245   CALL_SUBTEST(test_ref_in_expr());
    246   CALL_SUBTEST(test_coeff_ref());
    247   CALL_SUBTEST(test_nested_ops_with_ref());
    248 }
    249