Home | History | Annotate | Download | only in test
      1 #include "main.h"
      2 
      3 #include <Eigen/CXX11/Tensor>
      4 
      5 using Eigen::Tensor;
      6 
      7 static void test_single_voxel_patch()
      8 {
      9   Tensor<float, 5> tensor(4,2,3,5,7);
     10   tensor.setRandom();
     11   Tensor<float, 5, RowMajor> tensor_row_major = tensor.swap_layout();
     12 
     13   Tensor<float, 6> single_voxel_patch;
     14   single_voxel_patch = tensor.extract_volume_patches(1, 1, 1);
     15   VERIFY_IS_EQUAL(single_voxel_patch.dimension(0), 4);
     16   VERIFY_IS_EQUAL(single_voxel_patch.dimension(1), 1);
     17   VERIFY_IS_EQUAL(single_voxel_patch.dimension(2), 1);
     18   VERIFY_IS_EQUAL(single_voxel_patch.dimension(3), 1);
     19   VERIFY_IS_EQUAL(single_voxel_patch.dimension(4), 2 * 3 * 5);
     20   VERIFY_IS_EQUAL(single_voxel_patch.dimension(5), 7);
     21 
     22   Tensor<float, 6, RowMajor> single_voxel_patch_row_major;
     23   single_voxel_patch_row_major = tensor_row_major.extract_volume_patches(1, 1, 1);
     24   VERIFY_IS_EQUAL(single_voxel_patch_row_major.dimension(0), 7);
     25   VERIFY_IS_EQUAL(single_voxel_patch_row_major.dimension(1), 2 * 3 * 5);
     26   VERIFY_IS_EQUAL(single_voxel_patch_row_major.dimension(2), 1);
     27   VERIFY_IS_EQUAL(single_voxel_patch_row_major.dimension(3), 1);
     28   VERIFY_IS_EQUAL(single_voxel_patch_row_major.dimension(4), 1);
     29   VERIFY_IS_EQUAL(single_voxel_patch_row_major.dimension(5), 4);
     30 
     31   for (int i = 0; i < tensor.size(); ++i) {
     32     VERIFY_IS_EQUAL(tensor.data()[i], single_voxel_patch.data()[i]);
     33     VERIFY_IS_EQUAL(tensor_row_major.data()[i], single_voxel_patch_row_major.data()[i]);
     34     VERIFY_IS_EQUAL(tensor.data()[i], tensor_row_major.data()[i]);
     35   }
     36 }
     37 
     38 
     39 static void test_entire_volume_patch()
     40 {
     41   const int depth = 4;
     42   const int patch_z = 2;
     43   const int patch_y = 3;
     44   const int patch_x = 5;
     45   const int batch = 7;
     46 
     47   Tensor<float, 5> tensor(depth, patch_z, patch_y, patch_x, batch);
     48   tensor.setRandom();
     49   Tensor<float, 5, RowMajor> tensor_row_major = tensor.swap_layout();
     50 
     51   Tensor<float, 6> entire_volume_patch;
     52   entire_volume_patch = tensor.extract_volume_patches(patch_z, patch_y, patch_x);
     53   VERIFY_IS_EQUAL(entire_volume_patch.dimension(0), depth);
     54   VERIFY_IS_EQUAL(entire_volume_patch.dimension(1), patch_z);
     55   VERIFY_IS_EQUAL(entire_volume_patch.dimension(2), patch_y);
     56   VERIFY_IS_EQUAL(entire_volume_patch.dimension(3), patch_x);
     57   VERIFY_IS_EQUAL(entire_volume_patch.dimension(4), patch_z * patch_y * patch_x);
     58   VERIFY_IS_EQUAL(entire_volume_patch.dimension(5), batch);
     59 
     60   Tensor<float, 6, RowMajor> entire_volume_patch_row_major;
     61   entire_volume_patch_row_major = tensor_row_major.extract_volume_patches(patch_z, patch_y, patch_x);
     62   VERIFY_IS_EQUAL(entire_volume_patch_row_major.dimension(0), batch);
     63   VERIFY_IS_EQUAL(entire_volume_patch_row_major.dimension(1), patch_z * patch_y * patch_x);
     64   VERIFY_IS_EQUAL(entire_volume_patch_row_major.dimension(2), patch_x);
     65   VERIFY_IS_EQUAL(entire_volume_patch_row_major.dimension(3), patch_y);
     66   VERIFY_IS_EQUAL(entire_volume_patch_row_major.dimension(4), patch_z);
     67   VERIFY_IS_EQUAL(entire_volume_patch_row_major.dimension(5), depth);
     68 
     69   const int dz = patch_z - 1;
     70   const int dy = patch_y - 1;
     71   const int dx = patch_x - 1;
     72 
     73   const int forward_pad_z = dz - dz / 2;
     74   const int forward_pad_y = dy - dy / 2;
     75   const int forward_pad_x = dx - dx / 2;
     76 
     77   for (int pz = 0; pz < patch_z; pz++) {
     78     for (int py = 0; py < patch_y; py++) {
     79       for (int px = 0; px < patch_x; px++) {
     80         const int patchId = pz + patch_z * (py + px * patch_y);
     81         for (int z = 0; z < patch_z; z++) {
     82           for (int y = 0; y < patch_y; y++) {
     83             for (int x = 0; x < patch_x; x++) {
     84               for (int b = 0; b < batch; b++) {
     85                 for (int d = 0; d < depth; d++) {
     86                   float expected = 0.0f;
     87                   float expected_row_major = 0.0f;
     88                   const int eff_z = z - forward_pad_z + pz;
     89                   const int eff_y = y - forward_pad_y + py;
     90                   const int eff_x = x - forward_pad_x + px;
     91                   if (eff_z >= 0 && eff_y >= 0 && eff_x >= 0 &&
     92                       eff_z < patch_z && eff_y < patch_y && eff_x < patch_x) {
     93                     expected = tensor(d, eff_z, eff_y, eff_x, b);
     94                     expected_row_major = tensor_row_major(b, eff_x, eff_y, eff_z, d);
     95                   }
     96                   VERIFY_IS_EQUAL(entire_volume_patch(d, z, y, x, patchId, b), expected);
     97                   VERIFY_IS_EQUAL(entire_volume_patch_row_major(b, patchId, x, y, z, d), expected_row_major);
     98                 }
     99               }
    100             }
    101           }
    102         }
    103       }
    104     }
    105   }
    106 }
    107 
    108 void test_cxx11_tensor_volume_patch()
    109 {
    110   CALL_SUBTEST(test_single_voxel_patch());
    111   CALL_SUBTEST(test_entire_volume_patch());
    112 }
    113