Home | History | Annotate | Download | only in test
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2014 Navdeep Jaitly <ndjaitly (at) google.com> and
      5 //                    Benoit Steiner <benoit.steiner.goog (at) gmail.com>
      6 //
      7 // This Source Code Form is subject to the terms of the Mozilla
      8 // Public License v. 2.0. If a copy of the MPL was not distributed
      9 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
     10 
     11 #include "main.h"
     12 
     13 #include <Eigen/CXX11/Tensor>
     14 
     15 using Eigen::Tensor;
     16 using Eigen::array;
     17 
     18 template <int DataLayout>
     19 static void test_simple_reverse()
     20 {
     21   Tensor<float, 4, DataLayout> tensor(2,3,5,7);
     22   tensor.setRandom();
     23 
     24   array<bool, 4> dim_rev;
     25   dim_rev[0] = false;
     26   dim_rev[1] = true;
     27   dim_rev[2] = true;
     28   dim_rev[3] = false;
     29 
     30   Tensor<float, 4, DataLayout> reversed_tensor;
     31   reversed_tensor = tensor.reverse(dim_rev);
     32 
     33   VERIFY_IS_EQUAL(reversed_tensor.dimension(0), 2);
     34   VERIFY_IS_EQUAL(reversed_tensor.dimension(1), 3);
     35   VERIFY_IS_EQUAL(reversed_tensor.dimension(2), 5);
     36   VERIFY_IS_EQUAL(reversed_tensor.dimension(3), 7);
     37 
     38   for (int i = 0; i < 2; ++i) {
     39     for (int j = 0; j < 3; ++j) {
     40       for (int k = 0; k < 5; ++k) {
     41         for (int l = 0; l < 7; ++l) {
     42           VERIFY_IS_EQUAL(tensor(i,j,k,l), reversed_tensor(i,2-j,4-k,l));
     43         }
     44       }
     45     }
     46   }
     47 
     48   dim_rev[0] = true;
     49   dim_rev[1] = false;
     50   dim_rev[2] = false;
     51   dim_rev[3] = false;
     52 
     53   reversed_tensor = tensor.reverse(dim_rev);
     54 
     55   VERIFY_IS_EQUAL(reversed_tensor.dimension(0), 2);
     56   VERIFY_IS_EQUAL(reversed_tensor.dimension(1), 3);
     57   VERIFY_IS_EQUAL(reversed_tensor.dimension(2), 5);
     58   VERIFY_IS_EQUAL(reversed_tensor.dimension(3), 7);
     59 
     60 
     61   for (int i = 0; i < 2; ++i) {
     62     for (int j = 0; j < 3; ++j) {
     63       for (int k = 0; k < 5; ++k) {
     64         for (int l = 0; l < 7; ++l) {
     65           VERIFY_IS_EQUAL(tensor(i,j,k,l), reversed_tensor(1-i,j,k,l));
     66         }
     67       }
     68     }
     69   }
     70 
     71   dim_rev[0] = true;
     72   dim_rev[1] = false;
     73   dim_rev[2] = false;
     74   dim_rev[3] = true;
     75 
     76   reversed_tensor = tensor.reverse(dim_rev);
     77 
     78   VERIFY_IS_EQUAL(reversed_tensor.dimension(0), 2);
     79   VERIFY_IS_EQUAL(reversed_tensor.dimension(1), 3);
     80   VERIFY_IS_EQUAL(reversed_tensor.dimension(2), 5);
     81   VERIFY_IS_EQUAL(reversed_tensor.dimension(3), 7);
     82 
     83 
     84   for (int i = 0; i < 2; ++i) {
     85     for (int j = 0; j < 3; ++j) {
     86       for (int k = 0; k < 5; ++k) {
     87         for (int l = 0; l < 7; ++l) {
     88           VERIFY_IS_EQUAL(tensor(i,j,k,l), reversed_tensor(1-i,j,k,6-l));
     89         }
     90       }
     91     }
     92   }
     93 }
     94 
     95 
     96 template <int DataLayout>
     97 static void test_expr_reverse(bool LValue)
     98 {
     99   Tensor<float, 4, DataLayout> tensor(2,3,5,7);
    100   tensor.setRandom();
    101 
    102   array<bool, 4> dim_rev;
    103   dim_rev[0] = false;
    104   dim_rev[1] = true;
    105   dim_rev[2] = false;
    106   dim_rev[3] = true;
    107 
    108   Tensor<float, 4, DataLayout> expected(2, 3, 5, 7);
    109   if (LValue) {
    110     expected.reverse(dim_rev) = tensor;
    111   } else {
    112     expected = tensor.reverse(dim_rev);
    113   }
    114 
    115   Tensor<float, 4, DataLayout> result(2,3,5,7);
    116 
    117   array<ptrdiff_t, 4> src_slice_dim;
    118   src_slice_dim[0] = 2;
    119   src_slice_dim[1] = 3;
    120   src_slice_dim[2] = 1;
    121   src_slice_dim[3] = 7;
    122   array<ptrdiff_t, 4> src_slice_start;
    123   src_slice_start[0] = 0;
    124   src_slice_start[1] = 0;
    125   src_slice_start[2] = 0;
    126   src_slice_start[3] = 0;
    127   array<ptrdiff_t, 4> dst_slice_dim = src_slice_dim;
    128   array<ptrdiff_t, 4> dst_slice_start = src_slice_start;
    129 
    130   for (int i = 0; i < 5; ++i) {
    131     if (LValue) {
    132       result.slice(dst_slice_start, dst_slice_dim).reverse(dim_rev) =
    133           tensor.slice(src_slice_start, src_slice_dim);
    134     } else {
    135       result.slice(dst_slice_start, dst_slice_dim) =
    136           tensor.slice(src_slice_start, src_slice_dim).reverse(dim_rev);
    137     }
    138     src_slice_start[2] += 1;
    139     dst_slice_start[2] += 1;
    140   }
    141 
    142   VERIFY_IS_EQUAL(result.dimension(0), 2);
    143   VERIFY_IS_EQUAL(result.dimension(1), 3);
    144   VERIFY_IS_EQUAL(result.dimension(2), 5);
    145   VERIFY_IS_EQUAL(result.dimension(3), 7);
    146 
    147   for (int i = 0; i < expected.dimension(0); ++i) {
    148     for (int j = 0; j < expected.dimension(1); ++j) {
    149       for (int k = 0; k < expected.dimension(2); ++k) {
    150         for (int l = 0; l < expected.dimension(3); ++l) {
    151           VERIFY_IS_EQUAL(result(i,j,k,l), expected(i,j,k,l));
    152         }
    153       }
    154     }
    155   }
    156 
    157   dst_slice_start[2] = 0;
    158   result.setRandom();
    159   for (int i = 0; i < 5; ++i) {
    160      if (LValue) {
    161        result.slice(dst_slice_start, dst_slice_dim).reverse(dim_rev) =
    162            tensor.slice(dst_slice_start, dst_slice_dim);
    163      } else {
    164        result.slice(dst_slice_start, dst_slice_dim) =
    165            tensor.reverse(dim_rev).slice(dst_slice_start, dst_slice_dim);
    166      }
    167     dst_slice_start[2] += 1;
    168   }
    169 
    170   for (int i = 0; i < expected.dimension(0); ++i) {
    171     for (int j = 0; j < expected.dimension(1); ++j) {
    172       for (int k = 0; k < expected.dimension(2); ++k) {
    173         for (int l = 0; l < expected.dimension(3); ++l) {
    174           VERIFY_IS_EQUAL(result(i,j,k,l), expected(i,j,k,l));
    175         }
    176       }
    177     }
    178   }
    179 }
    180 
    181 
    // Entry point for the tensor reverse() tests: runs every sub-test for
    // both ColMajor and RowMajor layouts.
    // NOTE(review): each sub-test calls setRandom(), so reordering these
    // calls would presumably change the random data each one sees.
    182 void test_cxx11_tensor_reverse()
    183 {
          // Plain reversal into a fresh tensor.
    184   CALL_SUBTEST(test_simple_reverse<ColMajor>());
    185   CALL_SUBTEST(test_simple_reverse<RowMajor>());
          // LValue = true: the reverse() expression is the assignment target.
    186   CALL_SUBTEST(test_expr_reverse<ColMajor>(true));
    187   CALL_SUBTEST(test_expr_reverse<RowMajor>(true));
          // LValue = false: the reverse() expression is only read from.
    188   CALL_SUBTEST(test_expr_reverse<ColMajor>(false));
    189   CALL_SUBTEST(test_expr_reverse<RowMajor>(false));
    190 }
    191