Home | History | Annotate | Download | only in test
      // Warning suppression for GCC-compatible compilers. These directives sit
      // before the include guard, so they are evaluated on every inclusion.
      1 #ifdef __GNUC__
      2 #  pragma GCC diagnostic ignored "-Wmissing-declarations"
        // Additionally silenced when __clang__ or __APPLE__ is defined.
      3 #  if defined __clang__ || defined __APPLE__
      4 #    pragma GCC diagnostic ignored "-Wmissing-prototypes"
      5 #    pragma GCC diagnostic ignored "-Wextra"
      6 #  endif
      7 #endif
      8 
      9 #ifndef __OPENCV_TEST_PRECOMP_HPP__
     10 #define __OPENCV_TEST_PRECOMP_HPP__
     11 
     12 #include <iostream>
     13 #include <map>
     14 #include "opencv2/ts.hpp"
     15 #include "opencv2/ml.hpp"
     16 #include "opencv2/core/core_c.h"
     17 
     // String identifiers, one per ML model kind.
     // NOTE(review): presumably these are the values passed as `_modelName` to the
     // test-class constructors declared in this header — confirm in the test sources.
     18 #define CV_NBAYES   "nbayes"
     19 #define CV_KNEAREST "knearest"
     20 #define CV_SVM      "svm"
     21 #define CV_EM       "em"
     22 #define CV_ANN      "ann"
     23 #define CV_DTREE    "dtree"
     24 #define CV_BOOST    "boost"
     25 #define CV_RTREES   "rtrees"
     26 #define CV_ERTREES  "ertrees"
     27 
     // Unnamed enum with two constants: CV_TRAIN_ERROR (0) and CV_TEST_ERROR (1).
     // NOTE(review): presumably selects training-set vs. test-set error; the
     // constants are not referenced elsewhere in this header — confirm usage.
     28 enum { CV_TRAIN_ERROR=0, CV_TEST_ERROR=1 };
     29 
     // Make cv::Ptr and selected cv::ml types nameable without qualification.
     // These are file-scope using-declarations in a header, so they apply to
     // every translation unit that includes it. Only Ptr, StatModel and
     // TrainData are referenced by the declarations in this header itself.
     30 using cv::Ptr;
     31 using cv::ml::StatModel;
     32 using cv::ml::TrainData;
     33 using cv::ml::NormalBayesClassifier;
     34 using cv::ml::SVM;
     35 using cv::ml::KNearest;
     36 using cv::ml::ParamGrid;
     37 using cv::ml::ANN_MLP;
     38 using cv::ml::DTrees;
     39 using cv::ml::Boost;
     40 using cv::ml::RTrees;
     41 
     // Abstract base class for the ML tests declared in this header.
     // It derives from cvtest::BaseTest and leaves run_test_case() and
     // validate_test_results() pure virtual, so only subclasses that implement
     // both can be instantiated.
     // NOTE(review): the member-function definitions are not in this header; the
     // hedged notes below describe apparent intent and should be confirmed there.
     42 class CV_MLBaseTest : public cvtest::BaseTest
     43 {
     44 public:
            // _modelName: C string naming the model under test (presumably one of
            // the CV_* model-name macros defined in this header — TODO confirm).
     45     CV_MLBaseTest( const char* _modelName );
     46     virtual ~CV_MLBaseTest();
     47 protected:
            // Virtual hooks defined outside this header.
            // NOTE(review): these look like overrides of cvtest::BaseTest hooks — confirm.
     48     virtual int read_params( CvFileStorage* fs );
     49     virtual void run( int startFrom );
     50     virtual int prepare_test_case( int testCaseIdx );
            // Returns a mutable reference to a std::string (presumably validationFN — verify).
     51     virtual std::string& get_validation_filename();
            // Pure virtual: every concrete test must provide these two steps.
     52     virtual int run_test_case( int testCaseIdx ) = 0;
     53     virtual int validate_test_results( int testCaseIdx ) = 0;
     54 
            // Non-virtual helpers shared by subclasses.
     55     int train( int testCaseIdx );
            // resp is optional: it defaults to a null pointer.
            // NOTE(review): presumably receives per-sample responses when non-null — confirm.
     56     float get_test_error( int testCaseIdx, std::vector<float> *resp = 0 );
     57     void save( const char* filename );
     58     void load( const char* filename );
     59 
            // Training data handle (cv::ml::TrainData).
     60     Ptr<TrainData> data;
            // modelName presumably stores the constructor's _modelName; validationFN
            // presumably holds the validation file name — verify in the definitions.
     61     std::string modelName, validationFN;
     62     std::vector<std::string> dataSetNames;
     63     cv::FileStorage validationFS;
     64 
            // Model under test, held through the cv::ml::StatModel base interface.
     65     Ptr<StatModel> model;
     66 
            // int -> int map; NOTE(review): presumably a class-label mapping — confirm.
     67     std::map<int, int> cls_map;
     68 
     69     int64 initSeed;
     70 };
     71 
     // Concrete ML test built on CV_MLBaseTest: it declares implementations of
     // the two pure virtual steps, run_test_case() and validate_test_results().
     // It adds no data members of its own.
     // NOTE(review): the name suggests an accuracy test — confirm in the definitions.
     72 class CV_AMLTest : public CV_MLBaseTest
     73 {
     74 public:
            // _modelName: C string naming the model under test; definition not in this header.
     75     CV_AMLTest( const char* _modelName );
            // Empty destructor body defined inline.
     76     virtual ~CV_AMLTest() {}
     77 protected:
     78     virtual int run_test_case( int testCaseIdx );
     79     virtual int validate_test_results( int testCaseIdx );
     80 };
     81 
     // Concrete ML test built on CV_MLBaseTest: it declares implementations of
     // run_test_case() and validate_test_results(), and adds two response
     // vectors and two file-name strings.
     // NOTE(review): the name and the paired members suggest a save/load
     // round-trip comparison — confirm in the definitions.
     82 class CV_SLMLTest : public CV_MLBaseTest
     83 {
     84 public:
            // _modelName: C string naming the model under test; definition not in this header.
     85     CV_SLMLTest( const char* _modelName );
            // Empty destructor body defined inline.
     86     virtual ~CV_SLMLTest() {}
     87 protected:
     88     virtual int run_test_case( int testCaseIdx );
     89     virtual int validate_test_results( int testCaseIdx );
     90 
     91     std::vector<float> test_resps1, test_resps2; // predicted responses for test data
            // NOTE(review): presumably the file names paired with test_resps1/test_resps2 — confirm.
     92     std::string fname1, fname2;
     93 };
     94 
     95 #endif
     96