1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2015 Benoit Steiner <benoit.steiner.goog (at) gmail.com> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #include "main.h" 11 12 #include <Eigen/CXX11/Tensor> 13 14 struct Generator1D { 15 Generator1D() { } 16 17 float operator()(const array<Eigen::DenseIndex, 1>& coordinates) const { 18 return coordinates[0]; 19 } 20 }; 21 22 template <int DataLayout> 23 static void test_1D() 24 { 25 Tensor<float, 1> vec(6); 26 Tensor<float, 1> result = vec.generate(Generator1D()); 27 28 for (int i = 0; i < 6; ++i) { 29 VERIFY_IS_EQUAL(result(i), i); 30 } 31 } 32 33 34 struct Generator2D { 35 Generator2D() { } 36 37 float operator()(const array<Eigen::DenseIndex, 2>& coordinates) const { 38 return 3 * coordinates[0] + 11 * coordinates[1]; 39 } 40 }; 41 42 template <int DataLayout> 43 static void test_2D() 44 { 45 Tensor<float, 2> matrix(5, 7); 46 Tensor<float, 2> result = matrix.generate(Generator2D()); 47 48 for (int i = 0; i < 5; ++i) { 49 for (int j = 0; j < 5; ++j) { 50 VERIFY_IS_EQUAL(result(i, j), 3*i + 11*j); 51 } 52 } 53 } 54 55 56 template <int DataLayout> 57 static void test_gaussian() 58 { 59 int rows = 32; 60 int cols = 48; 61 array<float, 2> means; 62 means[0] = rows / 2.0f; 63 means[1] = cols / 2.0f; 64 array<float, 2> std_devs; 65 std_devs[0] = 3.14f; 66 std_devs[1] = 2.7f; 67 internal::GaussianGenerator<float, Eigen::DenseIndex, 2> gaussian_gen(means, std_devs); 68 69 Tensor<float, 2> matrix(rows, cols); 70 Tensor<float, 2> result = matrix.generate(gaussian_gen); 71 72 for (int i = 0; i < rows; ++i) { 73 for (int j = 0; j < cols; ++j) { 74 float g_rows = powf(rows/2.0f - i, 2) / (3.14f * 3.14f) * 0.5f; 75 float g_cols = powf(cols/2.0f - j, 2) / (2.7f * 2.7f) * 0.5f; 76 float gaussian = expf(-g_rows - g_cols); 77 VERIFY_IS_EQUAL(result(i, j), gaussian); 78 } 79 } 80 } 81 82 83 void test_cxx11_tensor_generator() 84 { 85 CALL_SUBTEST(test_1D<ColMajor>()); 86 CALL_SUBTEST(test_1D<RowMajor>()); 87 CALL_SUBTEST(test_2D<ColMajor>()); 88 CALL_SUBTEST(test_2D<RowMajor>()); 89 CALL_SUBTEST(test_gaussian<ColMajor>()); 90 CALL_SUBTEST(test_gaussian<RowMajor>()); 91 } 92