Home | History | Annotate | Download | only in test
      1 /*M///////////////////////////////////////////////////////////////////////////////////////
      2 //
      3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
      4 //
      5 //  By downloading, copying, installing or using the software you agree to this license.
      6 //  If you do not agree to this license, do not download, install,
      7 //  copy or use the software.
      8 //
      9 //
     10 //                        Intel License Agreement
     11 //                For Open Source Computer Vision Library
     12 //
     13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
     14 // Third party copyrights are property of their respective owners.
     15 //
     16 // Redistribution and use in source and binary forms, with or without modification,
     17 // are permitted provided that the following conditions are met:
     18 //
     19 //   * Redistribution's of source code must retain the above copyright notice,
     20 //     this list of conditions and the following disclaimer.
     21 //
     22 //   * Redistribution's in binary form must reproduce the above copyright notice,
     23 //     this list of conditions and the following disclaimer in the documentation
     24 //     and/or other materials provided with the distribution.
     25 //
     26 //   * The name of Intel Corporation may not be used to endorse or promote products
     27 //     derived from this software without specific prior written permission.
     28 //
     29 // This software is provided by the copyright holders and contributors "as is" and
     30 // any express or implied warranties, including, but not limited to, the implied
     31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
     32 // In no event shall the Intel Corporation or contributors be liable for any direct,
     33 // indirect, incidental, special, exemplary, or consequential damages
     34 // (including, but not limited to, procurement of substitute goods or services;
     35 // loss of use, data, or profits; or business interruption) however caused
     36 // and on any theory of liability, whether in contract, strict liability,
     37 // or tort (including negligence or otherwise) arising in any way out of
     38 // the use of this software, even if advised of the possibility of such damage.
     39 //
     40 //M*/
     41 
     42 #include "test_precomp.hpp"
     43 
     44 #include <iostream>
     45 #include <fstream>
     46 
     47 using namespace cv;
     48 using namespace std;
     49 
     50 CV_SLMLTest::CV_SLMLTest( const char* _modelName ) : CV_MLBaseTest( _modelName )
     51 {
     52     validationFN = "slvalidation.xml";
     53 }
     54 
     55 int CV_SLMLTest::run_test_case( int testCaseIdx )
     56 {
     57     int code = cvtest::TS::OK;
     58     code = prepare_test_case( testCaseIdx );
     59 
     60     if( code == cvtest::TS::OK )
     61     {
     62         data->setTrainTestSplit(data->getNTrainSamples(), true);
     63         code = train( testCaseIdx );
     64         if( code == cvtest::TS::OK )
     65         {
     66             get_test_error( testCaseIdx, &test_resps1 );
     67             fname1 = tempfile(".yml.gz");
     68             save( fname1.c_str() );
     69             load( fname1.c_str() );
     70             get_test_error( testCaseIdx, &test_resps2 );
     71             fname2 = tempfile(".yml.gz");
     72             save( fname2.c_str() );
     73         }
     74         else
     75             ts->printf( cvtest::TS::LOG, "model can not be trained" );
     76     }
     77     return code;
     78 }
     79 
     80 int CV_SLMLTest::validate_test_results( int testCaseIdx )
     81 {
     82     int code = cvtest::TS::OK;
     83 
     84     // 1. compare files
     85     FILE *fs1 = fopen(fname1.c_str(), "rb"), *fs2 = fopen(fname2.c_str(), "rb");
     86     size_t sz1 = 0, sz2 = 0;
     87     if( !fs1 || !fs2 )
     88         code = cvtest::TS::FAIL_MISSING_TEST_DATA;
     89     if( code >= 0 )
     90     {
     91         fseek(fs1, 0, SEEK_END); fseek(fs2, 0, SEEK_END);
     92         sz1 = ftell(fs1);
     93         sz2 = ftell(fs2);
     94         fseek(fs1, 0, SEEK_SET); fseek(fs2, 0, SEEK_SET);
     95     }
     96 
     97     if( sz1 != sz2 )
     98         code = cvtest::TS::FAIL_INVALID_OUTPUT;
     99 
    100     if( code >= 0 )
    101     {
    102         const int BUFSZ = 1024;
    103         uchar buf1[BUFSZ], buf2[BUFSZ];
    104         for( size_t pos = 0; pos < sz1;  )
    105         {
    106             size_t r1 = fread(buf1, 1, BUFSZ, fs1);
    107             size_t r2 = fread(buf2, 1, BUFSZ, fs2);
    108             if( r1 != r2 || memcmp(buf1, buf2, r1) != 0 )
    109             {
    110                 ts->printf( cvtest::TS::LOG,
    111                            "in test case %d first (%s) and second (%s) saved files differ in %d-th kb\n",
    112                            testCaseIdx, fname1.c_str(), fname2.c_str(),
    113                            (int)pos );
    114                 code = cvtest::TS::FAIL_INVALID_OUTPUT;
    115                 break;
    116             }
    117             pos += r1;
    118         }
    119     }
    120 
    121     if(fs1)
    122         fclose(fs1);
    123     if(fs2)
    124         fclose(fs2);
    125 
    126     // delete temporary files
    127     if( code >= 0 )
    128     {
    129         remove( fname1.c_str() );
    130         remove( fname2.c_str() );
    131     }
    132 
    133     if( code >= 0 )
    134     {
    135         // 2. compare responses
    136         CV_Assert( test_resps1.size() == test_resps2.size() );
    137         vector<float>::const_iterator it1 = test_resps1.begin(), it2 = test_resps2.begin();
    138         for( ; it1 != test_resps1.end(); ++it1, ++it2 )
    139         {
    140             if( fabs(*it1 - *it2) > FLT_EPSILON )
    141             {
    142                 ts->printf( cvtest::TS::LOG, "in test case %d responses predicted before saving and after loading is different", testCaseIdx );
    143                 code = cvtest::TS::FAIL_INVALID_OUTPUT;
    144                 break;
    145             }
    146         }
    147     }
    148     return code;
    149 }
    150 
    151 TEST(ML_NaiveBayes, save_load) { CV_SLMLTest test( CV_NBAYES ); test.safe_run(); }
    152 TEST(ML_KNearest, save_load) { CV_SLMLTest test( CV_KNEAREST ); test.safe_run(); }
    153 TEST(ML_SVM, save_load) { CV_SLMLTest test( CV_SVM ); test.safe_run(); }
    154 TEST(ML_ANN, save_load) { CV_SLMLTest test( CV_ANN ); test.safe_run(); }
    155 TEST(ML_DTree, save_load) { CV_SLMLTest test( CV_DTREE ); test.safe_run(); }
    156 TEST(ML_Boost, save_load) { CV_SLMLTest test( CV_BOOST ); test.safe_run(); }
    157 TEST(ML_RTrees, save_load) { CV_SLMLTest test( CV_RTREES ); test.safe_run(); }
    158 TEST(DISABLED_ML_ERTrees, save_load) { CV_SLMLTest test( CV_ERTREES ); test.safe_run(); }
    159 
    160 class CV_LegacyTest : public cvtest::BaseTest
    161 {
    162 public:
    163     CV_LegacyTest(const std::string &_modelName, const std::string &_suffixes = std::string())
    164         : cvtest::BaseTest(), modelName(_modelName), suffixes(_suffixes)
    165     {
    166     }
    167     virtual ~CV_LegacyTest() {}
    168 protected:
    169     void run(int)
    170     {
    171         unsigned int idx = 0;
    172         for (;;)
    173         {
    174             if (idx >= suffixes.size())
    175                 break;
    176             int found = (int)suffixes.find(';', idx);
    177             string piece = suffixes.substr(idx, found - idx);
    178             if (piece.empty())
    179                 break;
    180             oneTest(piece);
    181             idx += (unsigned int)piece.size() + 1;
    182         }
    183     }
    184     void oneTest(const string & suffix)
    185     {
    186         using namespace cv::ml;
    187 
    188         int code = cvtest::TS::OK;
    189         string filename = ts->get_data_path() + "legacy/" + modelName + suffix;
    190         bool isTree = modelName == CV_BOOST || modelName == CV_DTREE || modelName == CV_RTREES;
    191         Ptr<StatModel> model;
    192         if (modelName == CV_BOOST)
    193             model = Algorithm::load<Boost>(filename);
    194         else if (modelName == CV_ANN)
    195             model = Algorithm::load<ANN_MLP>(filename);
    196         else if (modelName == CV_DTREE)
    197             model = Algorithm::load<DTrees>(filename);
    198         else if (modelName == CV_NBAYES)
    199             model = Algorithm::load<NormalBayesClassifier>(filename);
    200         else if (modelName == CV_SVM)
    201             model = Algorithm::load<SVM>(filename);
    202         else if (modelName == CV_RTREES)
    203             model = Algorithm::load<RTrees>(filename);
    204         if (!model)
    205         {
    206             code = cvtest::TS::FAIL_INVALID_TEST_DATA;
    207         }
    208         else
    209         {
    210             Mat input = Mat(isTree ? 10 : 1, model->getVarCount(), CV_32F);
    211             ts->get_rng().fill(input, RNG::UNIFORM, 0, 40);
    212 
    213             if (isTree)
    214                 randomFillCategories(filename, input);
    215 
    216             Mat output;
    217             model->predict(input, output, StatModel::RAW_OUTPUT | (isTree ? DTrees::PREDICT_SUM : 0));
    218             // just check if no internal assertions or errors thrown
    219         }
    220         ts->set_failed_test_info(code);
    221     }
    222     void randomFillCategories(const string & filename, Mat & input)
    223     {
    224         Mat catMap;
    225         Mat catCount;
    226         std::vector<uchar> varTypes;
    227 
    228         FileStorage fs(filename, FileStorage::READ);
    229         FileNode root = fs.getFirstTopLevelNode();
    230         root["cat_map"] >> catMap;
    231         root["cat_count"] >> catCount;
    232         root["var_type"] >> varTypes;
    233 
    234         int offset = 0;
    235         int countOffset = 0;
    236         uint var = 0, varCount = (uint)varTypes.size();
    237         for (; var < varCount; ++var)
    238         {
    239             if (varTypes[var] == ml::VAR_CATEGORICAL)
    240             {
    241                 int size = catCount.at<int>(0, countOffset);
    242                 for (int row = 0; row < input.rows; ++row)
    243                 {
    244                     int randomChosenIndex = offset + ((uint)ts->get_rng()) % size;
    245                     int value = catMap.at<int>(0, randomChosenIndex);
    246                     input.at<float>(row, var) = (float)value;
    247                 }
    248                 offset += size;
    249                 ++countOffset;
    250             }
    251         }
    252     }
    253     string modelName;
    254     string suffixes;
    255 };
    256 
    257 TEST(ML_ANN, legacy_load) { CV_LegacyTest test(CV_ANN, "_waveform.xml"); test.safe_run(); }
    258 TEST(ML_Boost, legacy_load) { CV_LegacyTest test(CV_BOOST, "_adult.xml;_1.xml;_2.xml;_3.xml"); test.safe_run(); }
    259 TEST(ML_DTree, legacy_load) { CV_LegacyTest test(CV_DTREE, "_abalone.xml;_mushroom.xml"); test.safe_run(); }
    260 TEST(ML_NBayes, legacy_load) { CV_LegacyTest test(CV_NBAYES, "_waveform.xml"); test.safe_run(); }
    261 TEST(ML_SVM, legacy_load) { CV_LegacyTest test(CV_SVM, "_poletelecomm.xml;_waveform.xml"); test.safe_run(); }
    262 TEST(ML_RTrees, legacy_load) { CV_LegacyTest test(CV_RTREES, "_waveform.xml"); test.safe_run(); }
    263 
    264 /*TEST(ML_SVM, throw_exception_when_save_untrained_model)
    265 {
    266     Ptr<cv::ml::SVM> svm;
    267     string filename = tempfile("svm.xml");
    268     ASSERT_THROW(svm.save(filename.c_str()), Exception);
    269     remove(filename.c_str());
    270 }*/
    271 
    272 TEST(DISABLED_ML_SVM, linear_save_load)
    273 {
    274     Ptr<cv::ml::SVM> svm1, svm2, svm3;
    275 
    276     svm1 = Algorithm::load<SVM>("SVM45_X_38-1.xml");
    277     svm2 = Algorithm::load<SVM>("SVM45_X_38-2.xml");
    278     string tname = tempfile("a.xml");
    279     svm2->save(tname);
    280     svm3 = Algorithm::load<SVM>(tname);
    281 
    282     ASSERT_EQ(svm1->getVarCount(), svm2->getVarCount());
    283     ASSERT_EQ(svm1->getVarCount(), svm3->getVarCount());
    284 
    285     int m = 10000, n = svm1->getVarCount();
    286     Mat samples(m, n, CV_32F), r1, r2, r3;
    287     randu(samples, 0., 1.);
    288 
    289     svm1->predict(samples, r1);
    290     svm2->predict(samples, r2);
    291     svm3->predict(samples, r3);
    292 
    293     double eps = 1e-4;
    294     EXPECT_LE(norm(r1, r2, NORM_INF), eps);
    295     EXPECT_LE(norm(r1, r3, NORM_INF), eps);
    296 
    297     remove(tname.c_str());
    298 }
    299 
    300 /* End of file. */
    301