Home | History | Annotate | Download | only in cpp
#include "opencv2/ml/ml.hpp"
#include "opencv2/core/core.hpp"
#include "opencv2/core/utility.hpp"
#include <stdio.h>
#include <string.h>
#include <string>
#include <map>
      7 
      8 using namespace cv;
      9 using namespace cv::ml;
     10 
     11 static void help()
     12 {
     13     printf(
     14         "\nThis sample demonstrates how to use different decision trees and forests including boosting and random trees.\n"
     15         "Usage:\n\t./tree_engine [-r <response_column>] [-ts type_spec] <csv filename>\n"
     16         "where -r <response_column> specified the 0-based index of the response (0 by default)\n"
     17         "-ts specifies the var type spec in the form ord[n1,n2-n3,n4-n5,...]cat[m1-m2,m3,m4-m5,...]\n"
     18         "<csv filename> is the name of training data file in comma-separated value format\n\n");
     19 }
     20 
     21 static void train_and_print_errs(Ptr<StatModel> model, const Ptr<TrainData>& data)
     22 {
     23     bool ok = model->train(data);
     24     if( !ok )
     25     {
     26         printf("Training failed\n");
     27     }
     28     else
     29     {
     30         printf( "train error: %f\n", model->calcError(data, false, noArray()) );
     31         printf( "test error: %f\n\n", model->calcError(data, true, noArray()) );
     32     }
     33 }
     34 
     35 int main(int argc, char** argv)
     36 {
     37     if(argc < 2)
     38     {
     39         help();
     40         return 0;
     41     }
     42     const char* filename = 0;
     43     int response_idx = 0;
     44     std::string typespec;
     45 
     46     for(int i = 1; i < argc; i++)
     47     {
     48         if(strcmp(argv[i], "-r") == 0)
     49             sscanf(argv[++i], "%d", &response_idx);
     50         else if(strcmp(argv[i], "-ts") == 0)
     51             typespec = argv[++i];
     52         else if(argv[i][0] != '-' )
     53             filename = argv[i];
     54         else
     55         {
     56             printf("Error. Invalid option %s\n", argv[i]);
     57             help();
     58             return -1;
     59         }
     60     }
     61 
     62     printf("\nReading in %s...\n\n",filename);
     63     const double train_test_split_ratio = 0.5;
     64 
     65     Ptr<TrainData> data = TrainData::loadFromCSV(filename, 0, response_idx, response_idx+1, typespec);
     66 
     67     if( data.empty() )
     68     {
     69         printf("ERROR: File %s can not be read\n", filename);
     70         return 0;
     71     }
     72 
     73     data->setTrainTestSplitRatio(train_test_split_ratio);
     74 
     75     printf("======DTREE=====\n");
     76     Ptr<DTrees> dtree = DTrees::create();
     77     dtree->setMaxDepth(10);
     78     dtree->setMinSampleCount(2);
     79     dtree->setRegressionAccuracy(0);
     80     dtree->setUseSurrogates(false);
     81     dtree->setMaxCategories(16);
     82     dtree->setCVFolds(0);
     83     dtree->setUse1SERule(false);
     84     dtree->setTruncatePrunedTree(false);
     85     dtree->setPriors(Mat());
     86     train_and_print_errs(dtree, data);
     87 
     88     if( (int)data->getClassLabels().total() <= 2 ) // regression or 2-class classification problem
     89     {
     90         printf("======BOOST=====\n");
     91         Ptr<Boost> boost = Boost::create();
     92         boost->setBoostType(Boost::GENTLE);
     93         boost->setWeakCount(100);
     94         boost->setWeightTrimRate(0.95);
     95         boost->setMaxDepth(2);
     96         boost->setUseSurrogates(false);
     97         boost->setPriors(Mat());
     98         train_and_print_errs(boost, data);
     99     }
    100 
    101     printf("======RTREES=====\n");
    102     Ptr<RTrees> rtrees = RTrees::create();
    103     rtrees->setMaxDepth(10);
    104     rtrees->setMinSampleCount(2);
    105     rtrees->setRegressionAccuracy(0);
    106     rtrees->setUseSurrogates(false);
    107     rtrees->setMaxCategories(16);
    108     rtrees->setPriors(Mat());
    109     rtrees->setCalculateVarImportance(false);
    110     rtrees->setActiveVarCount(0);
    111     rtrees->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 0));
    112     train_and_print_errs(rtrees, data);
    113 
    114     return 0;
    115 }
    116