1 #include "opencv2/ml/ml.hpp" 2 #include "opencv2/core/core.hpp" 3 #include "opencv2/core/utility.hpp" 4 #include <stdio.h> 5 #include <string> 6 #include <map> 7 8 using namespace cv; 9 using namespace cv::ml; 10 11 static void help() 12 { 13 printf( 14 "\nThis sample demonstrates how to use different decision trees and forests including boosting and random trees.\n" 15 "Usage:\n\t./tree_engine [-r <response_column>] [-ts type_spec] <csv filename>\n" 16 "where -r <response_column> specified the 0-based index of the response (0 by default)\n" 17 "-ts specifies the var type spec in the form ord[n1,n2-n3,n4-n5,...]cat[m1-m2,m3,m4-m5,...]\n" 18 "<csv filename> is the name of training data file in comma-separated value format\n\n"); 19 } 20 21 static void train_and_print_errs(Ptr<StatModel> model, const Ptr<TrainData>& data) 22 { 23 bool ok = model->train(data); 24 if( !ok ) 25 { 26 printf("Training failed\n"); 27 } 28 else 29 { 30 printf( "train error: %f\n", model->calcError(data, false, noArray()) ); 31 printf( "test error: %f\n\n", model->calcError(data, true, noArray()) ); 32 } 33 } 34 35 int main(int argc, char** argv) 36 { 37 if(argc < 2) 38 { 39 help(); 40 return 0; 41 } 42 const char* filename = 0; 43 int response_idx = 0; 44 std::string typespec; 45 46 for(int i = 1; i < argc; i++) 47 { 48 if(strcmp(argv[i], "-r") == 0) 49 sscanf(argv[++i], "%d", &response_idx); 50 else if(strcmp(argv[i], "-ts") == 0) 51 typespec = argv[++i]; 52 else if(argv[i][0] != '-' ) 53 filename = argv[i]; 54 else 55 { 56 printf("Error. Invalid option %s\n", argv[i]); 57 help(); 58 return -1; 59 } 60 } 61 62 printf("\nReading in %s...\n\n",filename); 63 const double train_test_split_ratio = 0.5; 64 65 Ptr<TrainData> data = TrainData::loadFromCSV(filename, 0, response_idx, response_idx+1, typespec); 66 67 if( data.empty() ) 68 { 69 printf("ERROR: File %s can not be read\n", filename); 70 return 0; 71 } 72 73 data->setTrainTestSplitRatio(train_test_split_ratio); 74 75 printf("======DTREE=====\n"); 76 Ptr<DTrees> dtree = DTrees::create(); 77 dtree->setMaxDepth(10); 78 dtree->setMinSampleCount(2); 79 dtree->setRegressionAccuracy(0); 80 dtree->setUseSurrogates(false); 81 dtree->setMaxCategories(16); 82 dtree->setCVFolds(0); 83 dtree->setUse1SERule(false); 84 dtree->setTruncatePrunedTree(false); 85 dtree->setPriors(Mat()); 86 train_and_print_errs(dtree, data); 87 88 if( (int)data->getClassLabels().total() <= 2 ) // regression or 2-class classification problem 89 { 90 printf("======BOOST=====\n"); 91 Ptr<Boost> boost = Boost::create(); 92 boost->setBoostType(Boost::GENTLE); 93 boost->setWeakCount(100); 94 boost->setWeightTrimRate(0.95); 95 boost->setMaxDepth(2); 96 boost->setUseSurrogates(false); 97 boost->setPriors(Mat()); 98 train_and_print_errs(boost, data); 99 } 100 101 printf("======RTREES=====\n"); 102 Ptr<RTrees> rtrees = RTrees::create(); 103 rtrees->setMaxDepth(10); 104 rtrees->setMinSampleCount(2); 105 rtrees->setRegressionAccuracy(0); 106 rtrees->setUseSurrogates(false); 107 rtrees->setMaxCategories(16); 108 rtrees->setPriors(Mat()); 109 rtrees->setCalculateVarImportance(false); 110 rtrees->setActiveVarCount(0); 111 rtrees->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 0)); 112 train_and_print_errs(rtrees, data); 113 114 return 0; 115 } 116