Home | History | Annotate | Download | only in test
      1 
      2 #include "test_precomp.hpp"
      3 
      4 #if 0
      5 
      6 #include <string>
      7 #include <fstream>
      8 #include <iostream>
      9 
     10 using namespace std;
     11 
     12 
     13 class CV_GBTreesTest : public cvtest::BaseTest
     14 {
     15 public:
     16     CV_GBTreesTest();
     17     ~CV_GBTreesTest();
     18 
     19 protected:
     20     void run(int);
     21 
     22     int TestTrainPredict(int test_num);
     23     int TestSaveLoad();
     24 
     25     int checkPredictError(int test_num);
     26     int checkLoadSave();
     27 
     28     string model_file_name1;
     29     string model_file_name2;
     30 
     31     string* datasets;
     32     string data_path;
     33 
     34     CvMLData* data;
     35     CvGBTrees* gtb;
     36 
     37     vector<float> test_resps1;
     38     vector<float> test_resps2;
     39 
     40     int64 initSeed;
     41 };
     42 
     43 
     44 int _get_len(const CvMat* mat)
     45 {
     46     return (mat->cols > mat->rows) ? mat->cols : mat->rows;
     47 }
     48 
     49 
     50 CV_GBTreesTest::CV_GBTreesTest()
     51 {
     52     int64 seeds[] = { CV_BIG_INT(0x00009fff4f9c8d52),
     53                       CV_BIG_INT(0x0000a17166072c7c),
     54                       CV_BIG_INT(0x0201b32115cd1f9a),
     55                       CV_BIG_INT(0x0513cb37abcd1234),
     56                       CV_BIG_INT(0x0001a2b3c4d5f678)
     57                     };
     58 
     59     int seedCount = sizeof(seeds)/sizeof(seeds[0]);
     60     cv::RNG& rng = cv::theRNG();
     61     initSeed = rng.state;
     62     rng.state = seeds[rng(seedCount)];
     63 
     64     datasets = 0;
     65     data = 0;
     66     gtb = 0;
     67 }
     68 
     69 CV_GBTreesTest::~CV_GBTreesTest()
     70 {
     71     if (data)
     72         delete data;
     73     delete[] datasets;
     74     cv::theRNG().state = initSeed;
     75 }
     76 
     77 
     78 int CV_GBTreesTest::TestTrainPredict(int test_num)
     79 {
     80     int code = cvtest::TS::OK;
     81 
     82     int weak_count = 200;
     83     float shrinkage = 0.1f;
     84     float subsample_portion = 0.5f;
     85     int max_depth = 5;
     86     bool use_surrogates = false;
     87     int loss_function_type = 0;
     88     switch (test_num)
     89     {
     90         case (1) : loss_function_type = CvGBTrees::SQUARED_LOSS; break;
     91         case (2) : loss_function_type = CvGBTrees::ABSOLUTE_LOSS; break;
     92         case (3) : loss_function_type = CvGBTrees::HUBER_LOSS; break;
     93         case (0) : loss_function_type = CvGBTrees::DEVIANCE_LOSS; break;
     94         default  :
     95             {
     96             ts->printf( cvtest::TS::LOG, "Bad test_num value in CV_GBTreesTest::TestTrainPredict(..) function." );
     97             return cvtest::TS::FAIL_BAD_ARG_CHECK;
     98             }
     99     }
    100 
    101     int dataset_num = test_num == 0 ? 0 : 1;
    102     if (!data)
    103     {
    104         data = new CvMLData();
    105         data->set_delimiter(',');
    106 
    107         if (data->read_csv(datasets[dataset_num].c_str()))
    108         {
    109             ts->printf( cvtest::TS::LOG, "File reading error." );
    110             return cvtest::TS::FAIL_INVALID_TEST_DATA;
    111         }
    112 
    113         if (test_num == 0)
    114         {
    115             data->set_response_idx(57);
    116             data->set_var_types("ord[0-56],cat[57]");
    117         }
    118         else
    119         {
    120             data->set_response_idx(13);
    121             data->set_var_types("ord[0-2,4-13],cat[3]");
    122             subsample_portion = 0.7f;
    123         }
    124 
    125         int train_sample_count = cvFloor(_get_len(data->get_responses())*0.5f);
    126         CvTrainTestSplit spl( train_sample_count );
    127         data->set_train_test_split( &spl );
    128     }
    129 
    130     data->mix_train_and_test_idx();
    131 
    132 
    133     if (gtb) delete gtb;
    134     gtb = new CvGBTrees();
    135     bool tmp_code = true;
    136     tmp_code = gtb->train(data, CvGBTreesParams(loss_function_type, weak_count,
    137                           shrinkage, subsample_portion,
    138                           max_depth, use_surrogates));
    139 
    140     if (!tmp_code)
    141     {
    142         ts->printf( cvtest::TS::LOG, "Model training was failed.");
    143         return cvtest::TS::FAIL_INVALID_OUTPUT;
    144     }
    145 
    146     code = checkPredictError(test_num);
    147 
    148     return code;
    149 
    150 }
    151 
    152 
    153 int CV_GBTreesTest::checkPredictError(int test_num)
    154 {
    155     if (!gtb)
    156         return cvtest::TS::FAIL_GENERIC;
    157 
    158     //float mean[] = {5.430247f, 13.5654f, 12.6569f, 13.1661f};
    159     //float sigma[] = {0.4162694f, 3.21161f, 3.43297f, 3.00624f};
    160     float mean[] = {5.80226f, 12.68689f, 13.49095f, 13.19628f};
    161     float sigma[] = {0.4764534f, 3.166919f, 3.022405f, 2.868722f};
    162 
    163     float current_error = gtb->calc_error(data, CV_TEST_ERROR);
    164 
    165     if ( abs( current_error - mean[test_num]) > 6*sigma[test_num] )
    166     {
    167         ts->printf( cvtest::TS::LOG, "Test error is out of range:\n"
    168                     "abs(%f/*curEr*/ - %f/*mean*/ > %f/*6*sigma*/",
    169                     current_error, mean[test_num], 6*sigma[test_num] );
    170         return cvtest::TS::FAIL_BAD_ACCURACY;
    171     }
    172 
    173     return cvtest::TS::OK;
    174 
    175 }
    176 
    177 
    178 int CV_GBTreesTest::TestSaveLoad()
    179 {
    180     if (!gtb)
    181         return cvtest::TS::FAIL_GENERIC;
    182 
    183     model_file_name1 = cv::tempfile();
    184     model_file_name2 = cv::tempfile();
    185 
    186     gtb->save(model_file_name1.c_str());
    187     gtb->calc_error(data, CV_TEST_ERROR, &test_resps1);
    188     gtb->load(model_file_name1.c_str());
    189     gtb->calc_error(data, CV_TEST_ERROR, &test_resps2);
    190     gtb->save(model_file_name2.c_str());
    191 
    192     return checkLoadSave();
    193 
    194 }
    195 
    196 
    197 
    198 int CV_GBTreesTest::checkLoadSave()
    199 {
    200     int code = cvtest::TS::OK;
    201 
    202     // 1. compare files
    203     ifstream f1( model_file_name1.c_str() ), f2( model_file_name2.c_str() );
    204     string s1, s2;
    205     int lineIdx = 0;
    206     CV_Assert( f1.is_open() && f2.is_open() );
    207     for( ; !f1.eof() && !f2.eof(); lineIdx++ )
    208     {
    209         getline( f1, s1 );
    210         getline( f2, s2 );
    211         if( s1.compare(s2) )
    212         {
    213             ts->printf( cvtest::TS::LOG, "first and second saved files differ in %n-line; first %n line: %s; second %n-line: %s",
    214                lineIdx, lineIdx, s1.c_str(), lineIdx, s2.c_str() );
    215             code = cvtest::TS::FAIL_INVALID_OUTPUT;
    216         }
    217     }
    218     if( !f1.eof() || !f2.eof() )
    219     {
    220         ts->printf( cvtest::TS::LOG, "First and second saved files differ in %n-line; first %n line: %s; second %n-line: %s",
    221             lineIdx, lineIdx, s1.c_str(), lineIdx, s2.c_str() );
    222         code = cvtest::TS::FAIL_INVALID_OUTPUT;
    223     }
    224     f1.close();
    225     f2.close();
    226     // delete temporary files
    227     remove( model_file_name1.c_str() );
    228     remove( model_file_name2.c_str() );
    229 
    230     // 2. compare responses
    231     CV_Assert( test_resps1.size() == test_resps2.size() );
    232     vector<float>::const_iterator it1 = test_resps1.begin(), it2 = test_resps2.begin();
    233     for( ; it1 != test_resps1.end(); ++it1, ++it2 )
    234     {
    235         if( fabs(*it1 - *it2) > FLT_EPSILON )
    236         {
    237             ts->printf( cvtest::TS::LOG, "Responses predicted before saving and after loading are different" );
    238             code = cvtest::TS::FAIL_INVALID_OUTPUT;
    239         }
    240     }
    241     return code;
    242 }
    243 
    244 
    245 
    246 void CV_GBTreesTest::run(int)
    247 {
    248 
    249     string dataPath = string(ts->get_data_path());
    250     datasets = new string[2];
    251     datasets[0] = dataPath + string("spambase.data"); /*string("dataset_classification.csv");*/
    252     datasets[1] = dataPath + string("housing_.data");  /*string("dataset_regression.csv");*/
    253 
    254     int code = cvtest::TS::OK;
    255 
    256     for (int i = 0; i < 4; i++)
    257     {
    258 
    259         int temp_code = TestTrainPredict(i);
    260         if (temp_code != cvtest::TS::OK)
    261         {
    262             code = temp_code;
    263             break;
    264         }
    265 
    266         else if (i==0)
    267         {
    268             temp_code = TestSaveLoad();
    269             if (temp_code != cvtest::TS::OK)
    270                 code = temp_code;
    271             delete data;
    272             data = 0;
    273         }
    274 
    275         delete gtb;
    276         gtb = 0;
    277     }
    278     delete data;
    279     data = 0;
    280 
    281     ts->set_failed_test_info( code );
    282 }
    283 
    284 /////////////////////////////////////////////////////////////////////////////
    285 //////////////////// test registration  /////////////////////////////////////
    286 /////////////////////////////////////////////////////////////////////////////
    287 
    288 TEST(ML_GBTrees, regression) { CV_GBTreesTest test; test.safe_run(); }
    289 
    290 #endif
    291