Home | History | Annotate | Download | only in src
      1 /*M///////////////////////////////////////////////////////////////////////////////////////
      2 //
      3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
      4 //
      5 //  By downloading, copying, installing or using the software you agree to this license.
      6 //  If you do not agree to this license, do not download, install,
      7 //  copy or use the software.
      8 //
      9 //
     10 //                           License Agreement
     11 //                For Open Source Computer Vision Library
     12 //
     13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
     14 // Copyright (C) 2014, Itseez Inc, all rights reserved.
     15 // Third party copyrights are property of their respective owners.
     16 //
     17 // Redistribution and use in source and binary forms, with or without modification,
     18 // are permitted provided that the following conditions are met:
     19 //
     20 //   * Redistribution's of source code must retain the above copyright notice,
     21 //     this list of conditions and the following disclaimer.
     22 //
     23 //   * Redistribution's in binary form must reproduce the above copyright notice,
     24 //     this list of conditions and the following disclaimer in the documentation
     25 //     and/or other materials provided with the distribution.
     26 //
     27 //   * The name of the copyright holders may not be used to endorse or promote products
     28 //     derived from this software without specific prior written permission.
     29 //
     30 // This software is provided by the copyright holders and contributors "as is" and
     31 // any express or implied warranties, including, but not limited to, the implied
     32 // warranties of merchantability and fitness for a particular purpose are disclaimed.
     33 // In no event shall the Intel Corporation or contributors be liable for any direct,
     34 // indirect, incidental, special, exemplary, or consequential damages
     35 // (including, but not limited to, procurement of substitute goods or services;
     36 // loss of use, data, or profits; or business interruption) however caused
     37 // and on any theory of liability, whether in contract, strict liability,
     38 // or tort (including negligence or otherwise) arising in any way out of
     39 // the use of this software, even if advised of the possibility of such damage.
     40 //
     41 //M*/
     42 
     43 #include "precomp.hpp"
     44 
     45 namespace cv {
     46 namespace ml {
     47 
     48 //////////////////////////////////////////////////////////////////////////////////////////
     49 //                                  Random trees                                        //
     50 //////////////////////////////////////////////////////////////////////////////////////////
     51 RTreeParams::RTreeParams()
     52 {
     53     calcVarImportance = false;
     54     nactiveVars = 0;
     55     termCrit = TermCriteria(TermCriteria::EPS + TermCriteria::COUNT, 50, 0.1);
     56 }
     57 
     58 RTreeParams::RTreeParams(bool _calcVarImportance,
     59                          int _nactiveVars,
     60                          TermCriteria _termCrit )
     61 {
     62     calcVarImportance = _calcVarImportance;
     63     nactiveVars = _nactiveVars;
     64     termCrit = _termCrit;
     65 }
     66 
     67 
     68 class DTreesImplForRTrees : public DTreesImpl
     69 {
     70 public:
     71     DTreesImplForRTrees()
     72     {
     73         params.setMaxDepth(5);
     74         params.setMinSampleCount(10);
     75         params.setRegressionAccuracy(0.f);
     76         params.useSurrogates = false;
     77         params.setMaxCategories(10);
     78         params.setCVFolds(0);
     79         params.use1SERule = false;
     80         params.truncatePrunedTree = false;
     81         params.priors = Mat();
     82     }
     83     virtual ~DTreesImplForRTrees() {}
     84 
     85     void clear()
     86     {
     87         DTreesImpl::clear();
     88         oobError = 0.;
     89         rng = RNG((uint64)-1);
     90     }
     91 
     92     const vector<int>& getActiveVars()
     93     {
     94         int i, nvars = (int)allVars.size(), m = (int)activeVars.size();
     95         for( i = 0; i < nvars; i++ )
     96         {
     97             int i1 = rng.uniform(0, nvars);
     98             int i2 = rng.uniform(0, nvars);
     99             std::swap(allVars[i1], allVars[i2]);
    100         }
    101         for( i = 0; i < m; i++ )
    102             activeVars[i] = allVars[i];
    103         return activeVars;
    104     }
    105 
    106     void startTraining( const Ptr<TrainData>& trainData, int flags )
    107     {
    108         DTreesImpl::startTraining(trainData, flags);
    109         int nvars = w->data->getNVars();
    110         int i, m = rparams.nactiveVars > 0 ? rparams.nactiveVars : cvRound(std::sqrt((double)nvars));
    111         m = std::min(std::max(m, 1), nvars);
    112         allVars.resize(nvars);
    113         activeVars.resize(m);
    114         for( i = 0; i < nvars; i++ )
    115             allVars[i] = varIdx[i];
    116     }
    117 
    118     void endTraining()
    119     {
    120         DTreesImpl::endTraining();
    121         vector<int> a, b;
    122         std::swap(allVars, a);
    123         std::swap(activeVars, b);
    124     }
    125 
    126     bool train( const Ptr<TrainData>& trainData, int flags )
    127     {
    128         startTraining(trainData, flags);
    129         int treeidx, ntrees = (rparams.termCrit.type & TermCriteria::COUNT) != 0 ?
    130             rparams.termCrit.maxCount : 10000;
    131         int i, j, k, vi, vi_, n = (int)w->sidx.size();
    132         int nclasses = (int)classLabels.size();
    133         double eps = (rparams.termCrit.type & TermCriteria::EPS) != 0 &&
    134             rparams.termCrit.epsilon > 0 ? rparams.termCrit.epsilon : 0.;
    135         vector<int> sidx(n);
    136         vector<uchar> oobmask(n);
    137         vector<int> oobidx;
    138         vector<int> oobperm;
    139         vector<double> oobres(n, 0.);
    140         vector<int> oobcount(n, 0);
    141         vector<int> oobvotes(n*nclasses, 0);
    142         int nvars = w->data->getNVars();
    143         int nallvars = w->data->getNAllVars();
    144         const int* vidx = !varIdx.empty() ? &varIdx[0] : 0;
    145         vector<float> samplebuf(nallvars);
    146         Mat samples = w->data->getSamples();
    147         float* psamples = samples.ptr<float>();
    148         size_t sstep0 = samples.step1(), sstep1 = 1;
    149         Mat sample0, sample(nallvars, 1, CV_32F, &samplebuf[0]);
    150         int predictFlags = _isClassifier ? (PREDICT_MAX_VOTE + RAW_OUTPUT) : PREDICT_SUM;
    151 
    152         bool calcOOBError = eps > 0 || rparams.calcVarImportance;
    153         double max_response = 0.;
    154 
    155         if( w->data->getLayout() == COL_SAMPLE )
    156             std::swap(sstep0, sstep1);
    157 
    158         if( !_isClassifier )
    159         {
    160             for( i = 0; i < n; i++ )
    161             {
    162                 double val = std::abs(w->ord_responses[w->sidx[i]]);
    163                 max_response = std::max(max_response, val);
    164             }
    165         }
    166 
    167         if( rparams.calcVarImportance )
    168             varImportance.resize(nallvars, 0.f);
    169 
    170         for( treeidx = 0; treeidx < ntrees; treeidx++ )
    171         {
    172             for( i = 0; i < n; i++ )
    173                 oobmask[i] = (uchar)1;
    174 
    175             for( i = 0; i < n; i++ )
    176             {
    177                 j = rng.uniform(0, n);
    178                 sidx[i] = w->sidx[j];
    179                 oobmask[j] = (uchar)0;
    180             }
    181             int root = addTree( sidx );
    182             if( root < 0 )
    183                 return false;
    184 
    185             if( calcOOBError )
    186             {
    187                 oobidx.clear();
    188                 for( i = 0; i < n; i++ )
    189                 {
    190                     if( !oobmask[i] )
    191                         oobidx.push_back(i);
    192                 }
    193                 int n_oob = (int)oobidx.size();
    194                 // if there is no out-of-bag samples, we can not compute OOB error
    195                 // nor update the variable importance vector; so we proceed to the next tree
    196                 if( n_oob == 0 )
    197                     continue;
    198                 double ncorrect_responses = 0.;
    199 
    200                 oobError = 0.;
    201                 for( i = 0; i < n_oob; i++ )
    202                 {
    203                     j = oobidx[i];
    204                     sample = Mat( nallvars, 1, CV_32F, psamples + sstep0*w->sidx[j], sstep1*sizeof(psamples[0]) );
    205 
    206                     double val = predictTrees(Range(treeidx, treeidx+1), sample, predictFlags);
    207                     if( !_isClassifier )
    208                     {
    209                         oobres[j] += val;
    210                         oobcount[j]++;
    211                         double true_val = w->ord_responses[w->sidx[j]];
    212                         double a = oobres[j]/oobcount[j] - true_val;
    213                         oobError += a*a;
    214                         val = (val - true_val)/max_response;
    215                         ncorrect_responses += std::exp( -val*val );
    216                     }
    217                     else
    218                     {
    219                         int ival = cvRound(val);
    220                         int* votes = &oobvotes[j*nclasses];
    221                         votes[ival]++;
    222                         int best_class = 0;
    223                         for( k = 1; k < nclasses; k++ )
    224                             if( votes[best_class] < votes[k] )
    225                                 best_class = k;
    226                         int diff = best_class != w->cat_responses[w->sidx[j]];
    227                         oobError += diff;
    228                         ncorrect_responses += diff == 0;
    229                     }
    230                 }
    231 
    232                 oobError /= n_oob;
    233                 if( rparams.calcVarImportance && n_oob > 1 )
    234                 {
    235                     oobperm.resize(n_oob);
    236                     for( i = 0; i < n_oob; i++ )
    237                         oobperm[i] = oobidx[i];
    238 
    239                     for( vi_ = 0; vi_ < nvars; vi_++ )
    240                     {
    241                         vi = vidx ? vidx[vi_] : vi_;
    242                         double ncorrect_responses_permuted = 0;
    243                         for( i = 0; i < n_oob; i++ )
    244                         {
    245                             int i1 = rng.uniform(0, n_oob);
    246                             int i2 = rng.uniform(0, n_oob);
    247                             std::swap(i1, i2);
    248                         }
    249 
    250                         for( i = 0; i < n_oob; i++ )
    251                         {
    252                             j = oobidx[i];
    253                             int vj = oobperm[i];
    254                             sample0 = Mat( nallvars, 1, CV_32F, psamples + sstep0*w->sidx[j], sstep1*sizeof(psamples[0]) );
    255                             for( k = 0; k < nallvars; k++ )
    256                                 sample.at<float>(k) = sample0.at<float>(k);
    257                             sample.at<float>(vi) = psamples[sstep0*w->sidx[vj] + sstep1*vi];
    258 
    259                             double val = predictTrees(Range(treeidx, treeidx+1), sample, predictFlags);
    260                             if( !_isClassifier )
    261                             {
    262                                 val = (val - w->ord_responses[w->sidx[j]])/max_response;
    263                                 ncorrect_responses_permuted += exp( -val*val );
    264                             }
    265                             else
    266                                 ncorrect_responses_permuted += cvRound(val) == w->cat_responses[w->sidx[j]];
    267                         }
    268                         varImportance[vi] += (float)(ncorrect_responses - ncorrect_responses_permuted);
    269                     }
    270                 }
    271             }
    272             if( calcOOBError && oobError < eps )
    273                 break;
    274         }
    275 
    276         if( rparams.calcVarImportance )
    277         {
    278             for( vi_ = 0; vi_ < nallvars; vi_++ )
    279                 varImportance[vi_] = std::max(varImportance[vi_], 0.f);
    280             normalize(varImportance, varImportance, 1., 0, NORM_L1);
    281         }
    282         endTraining();
    283         return true;
    284     }
    285 
    286     void writeTrainingParams( FileStorage& fs ) const
    287     {
    288         DTreesImpl::writeTrainingParams(fs);
    289         fs << "nactive_vars" << rparams.nactiveVars;
    290     }
    291 
    292     void write( FileStorage& fs ) const
    293     {
    294         if( roots.empty() )
    295             CV_Error( CV_StsBadArg, "RTrees have not been trained" );
    296 
    297         writeParams(fs);
    298 
    299         fs << "oob_error" << oobError;
    300         if( !varImportance.empty() )
    301             fs << "var_importance" << varImportance;
    302 
    303         int k, ntrees = (int)roots.size();
    304 
    305         fs << "ntrees" << ntrees
    306            << "trees" << "[";
    307 
    308         for( k = 0; k < ntrees; k++ )
    309         {
    310             fs << "{";
    311             writeTree(fs, roots[k]);
    312             fs << "}";
    313         }
    314 
    315         fs << "]";
    316     }
    317 
    318     void readParams( const FileNode& fn )
    319     {
    320         DTreesImpl::readParams(fn);
    321 
    322         FileNode tparams_node = fn["training_params"];
    323         rparams.nactiveVars = (int)tparams_node["nactive_vars"];
    324     }
    325 
    326     void read( const FileNode& fn )
    327     {
    328         clear();
    329 
    330         //int nclasses = (int)fn["nclasses"];
    331         //int nsamples = (int)fn["nsamples"];
    332         oobError = (double)fn["oob_error"];
    333         int ntrees = (int)fn["ntrees"];
    334 
    335         readVectorOrMat(fn["var_importance"], varImportance);
    336 
    337         readParams(fn);
    338 
    339         FileNode trees_node = fn["trees"];
    340         FileNodeIterator it = trees_node.begin();
    341         CV_Assert( ntrees == (int)trees_node.size() );
    342 
    343         for( int treeidx = 0; treeidx < ntrees; treeidx++, ++it )
    344         {
    345             FileNode nfn = (*it)["nodes"];
    346             readTree(nfn);
    347         }
    348     }
    349 
    350     RTreeParams rparams;
    351     double oobError;
    352     vector<float> varImportance;
    353     vector<int> allVars, activeVars;
    354     RNG rng;
    355 };
    356 
    357 
    358 class RTreesImpl : public RTrees
    359 {
    360 public:
    361     CV_IMPL_PROPERTY(bool, CalculateVarImportance, impl.rparams.calcVarImportance)
    362     CV_IMPL_PROPERTY(int, ActiveVarCount, impl.rparams.nactiveVars)
    363     CV_IMPL_PROPERTY_S(TermCriteria, TermCriteria, impl.rparams.termCrit)
    364 
    365     CV_WRAP_SAME_PROPERTY(int, MaxCategories, impl.params)
    366     CV_WRAP_SAME_PROPERTY(int, MaxDepth, impl.params)
    367     CV_WRAP_SAME_PROPERTY(int, MinSampleCount, impl.params)
    368     CV_WRAP_SAME_PROPERTY(int, CVFolds, impl.params)
    369     CV_WRAP_SAME_PROPERTY(bool, UseSurrogates, impl.params)
    370     CV_WRAP_SAME_PROPERTY(bool, Use1SERule, impl.params)
    371     CV_WRAP_SAME_PROPERTY(bool, TruncatePrunedTree, impl.params)
    372     CV_WRAP_SAME_PROPERTY(float, RegressionAccuracy, impl.params)
    373     CV_WRAP_SAME_PROPERTY_S(cv::Mat, Priors, impl.params)
    374 
    375     RTreesImpl() {}
    376     virtual ~RTreesImpl() {}
    377 
    378     String getDefaultName() const { return "opencv_ml_rtrees"; }
    379 
    380     bool train( const Ptr<TrainData>& trainData, int flags )
    381     {
    382         return impl.train(trainData, flags);
    383     }
    384 
    385     float predict( InputArray samples, OutputArray results, int flags ) const
    386     {
    387         return impl.predict(samples, results, flags);
    388     }
    389 
    390     void write( FileStorage& fs ) const
    391     {
    392         impl.write(fs);
    393     }
    394 
    395     void read( const FileNode& fn )
    396     {
    397         impl.read(fn);
    398     }
    399 
    400     Mat getVarImportance() const { return Mat_<float>(impl.varImportance, true); }
    401     int getVarCount() const { return impl.getVarCount(); }
    402 
    403     bool isTrained() const { return impl.isTrained(); }
    404     bool isClassifier() const { return impl.isClassifier(); }
    405 
    406     const vector<int>& getRoots() const { return impl.getRoots(); }
    407     const vector<Node>& getNodes() const { return impl.getNodes(); }
    408     const vector<Split>& getSplits() const { return impl.getSplits(); }
    409     const vector<int>& getSubsets() const { return impl.getSubsets(); }
    410 
    411     DTreesImplForRTrees impl;
    412 };
    413 
    414 
    415 Ptr<RTrees> RTrees::create()
    416 {
    417     return makePtr<RTreesImpl>();
    418 }
    419 
    420 }}
    421 
    422 // End of file.
    423