1 #ifdef __GNUC__ 2 # pragma GCC diagnostic ignored "-Wmissing-declarations" 3 # if defined __clang__ || defined __APPLE__ 4 # pragma GCC diagnostic ignored "-Wmissing-prototypes" 5 # pragma GCC diagnostic ignored "-Wextra" 6 # endif 7 #endif 8 9 #ifndef __OPENCV_TEST_PRECOMP_HPP__ 10 #define __OPENCV_TEST_PRECOMP_HPP__ 11 12 #include <iostream> 13 #include <map> 14 #include "opencv2/ts.hpp" 15 #include "opencv2/ml.hpp" 16 #include "opencv2/core/core_c.h" 17 18 #define CV_NBAYES "nbayes" 19 #define CV_KNEAREST "knearest" 20 #define CV_SVM "svm" 21 #define CV_EM "em" 22 #define CV_ANN "ann" 23 #define CV_DTREE "dtree" 24 #define CV_BOOST "boost" 25 #define CV_RTREES "rtrees" 26 #define CV_ERTREES "ertrees" 27 28 enum { CV_TRAIN_ERROR=0, CV_TEST_ERROR=1 }; 29 30 using cv::Ptr; 31 using cv::ml::StatModel; 32 using cv::ml::TrainData; 33 using cv::ml::NormalBayesClassifier; 34 using cv::ml::SVM; 35 using cv::ml::KNearest; 36 using cv::ml::ParamGrid; 37 using cv::ml::ANN_MLP; 38 using cv::ml::DTrees; 39 using cv::ml::Boost; 40 using cv::ml::RTrees; 41 42 class CV_MLBaseTest : public cvtest::BaseTest 43 { 44 public: 45 CV_MLBaseTest( const char* _modelName ); 46 virtual ~CV_MLBaseTest(); 47 protected: 48 virtual int read_params( CvFileStorage* fs ); 49 virtual void run( int startFrom ); 50 virtual int prepare_test_case( int testCaseIdx ); 51 virtual std::string& get_validation_filename(); 52 virtual int run_test_case( int testCaseIdx ) = 0; 53 virtual int validate_test_results( int testCaseIdx ) = 0; 54 55 int train( int testCaseIdx ); 56 float get_test_error( int testCaseIdx, std::vector<float> *resp = 0 ); 57 void save( const char* filename ); 58 void load( const char* filename ); 59 60 Ptr<TrainData> data; 61 std::string modelName, validationFN; 62 std::vector<std::string> dataSetNames; 63 cv::FileStorage validationFS; 64 65 Ptr<StatModel> model; 66 67 std::map<int, int> cls_map; 68 69 int64 initSeed; 70 }; 71 72 class CV_AMLTest : public CV_MLBaseTest 73 { 74 public: 75 CV_AMLTest( const char* _modelName ); 76 virtual ~CV_AMLTest() {} 77 protected: 78 virtual int run_test_case( int testCaseIdx ); 79 virtual int validate_test_results( int testCaseIdx ); 80 }; 81 82 class CV_SLMLTest : public CV_MLBaseTest 83 { 84 public: 85 CV_SLMLTest( const char* _modelName ); 86 virtual ~CV_SLMLTest() {} 87 protected: 88 virtual int run_test_case( int testCaseIdx ); 89 virtual int validate_test_results( int testCaseIdx ); 90 91 std::vector<float> test_resps1, test_resps2; // predicted responses for test data 92 std::string fname1, fname2; 93 }; 94 95 #endif 96