/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Copyright (C) 2014, Itseez Inc, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"

namespace cv { namespace ml {

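// Returns the log-odds log(val/(1-val)), clamping val to [eps, 1-eps]
// so the result stays finite at the boundaries.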
static inline double
log_ratio( double val )
{
    const double eps = 1e-5;
    val = std::max( val, eps );
    val = std::min( val, 1. - eps );
    return log( val/(1. - val) );
}


BoostTreeParams::BoostTreeParams()
{
    boostType = Boost::REAL;
    weakCount = 100;
    weightTrimRate = 0.95;
}

BoostTreeParams::BoostTreeParams( int _boostType, int _weak_count,
                                  double _weightTrimRate)
{
    boostType = _boostType;
    weakCount = _weak_count;
    weightTrimRate = _weightTrimRate;
}

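// Decision-tree implementation specialized for boosting: by default it grows
// depth-1 trees (stumps) and disables the built-in cross-validation pruning.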
class DTreesImplForBoost : public DTreesImpl
{
public:
    DTreesImplForBoost()
    {
        params.setCVFolds(0);
        params.setMaxDepth(1);
    }
    virtual ~DTreesImplForBoost() {}

    bool isClassifier() const { return true; }

    void clear()
    {
        DTreesImpl::clear();
    }

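    // For the non-discrete variants the weak learners are regression trees, so
    // the two class labels are remapped to ordered responses: -1/+1 for Real
    // and Gentle AdaBoost, -2/+2 for LogitBoost.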
    void startTraining( const Ptr<TrainData>& trainData, int flags )
    {
        DTreesImpl::startTraining(trainData, flags);
        sumResult.assign(w->sidx.size(), 0.);

        if( bparams.boostType != Boost::DISCRETE )
        {
            _isClassifier = false;
            int i, n = (int)w->cat_responses.size();
            w->ord_responses.resize(n);

            double a = -1, b = 1;
            if( bparams.boostType == Boost::LOGIT )
            {
                a = -2, b = 2;
            }
            for( i = 0; i < n; i++ )
                w->ord_responses[i] = w->cat_responses[i] > 0 ? b : a;
        }

        normalizeWeights();
    }

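    // Rescales the sample weights so that they sum to 1; if the total weight
    // has degenerated to (near) zero, every weight is reset to 1 instead.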
    void normalizeWeights()
    {
        int i, n = (int)w->sidx.size();
        double sumw = 0, a, b;
        for( i = 0; i < n; i++ )
            sumw += w->sample_weights[w->sidx[i]];
        if( sumw > DBL_EPSILON )
        {
            a = 1./sumw;
            b = 0;
        }
        else
        {
            a = 0;
            b = 1;
        }
        for( i = 0; i < n; i++ )
        {
            double& wval = w->sample_weights[w->sidx[i]];
            wval = wval*a + b;
        }
    }

    void endTraining()
    {
        DTreesImpl::endTraining();
        vector<double> e;
        std::swap(sumResult, e);
    }

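    // Multiplies the value of every node in the tree rooted at 'root' by
    // 'scale'. Discrete AdaBoost uses this to fold the classifier weight C
    // into the tree.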
    void scaleTree( int root, double scale )
    {
        int nidx = root, pidx = 0;
        Node *node = 0;

        // traverse the tree in depth-first order and scale the value of each node
        for(;;)
        {
            for(;;)
            {
                node = &nodes[nidx];
                node->value *= scale;
                if( node->left < 0 )
                    break;
                nidx = node->left;
            }

            for( pidx = node->parent; pidx >= 0 && nodes[pidx].right == nidx;
                 nidx = pidx, pidx = nodes[pidx].parent )
                ;

            if( pidx < 0 )
                break;

            nidx = nodes[pidx].right;
        }
    }

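    // Computes the node value, then transforms it for the boosting type in
    // use: Discrete AdaBoost stores the class as -1/+1; Real AdaBoost stores
    // half the log-odds of the estimated class probability.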
    void calcValue( int nidx, const vector<int>& _sidx )
    {
        DTreesImpl::calcValue(nidx, _sidx);
        WNode* node = &w->wnodes[nidx];
        if( bparams.boostType == Boost::DISCRETE )
        {
            node->value = node->class_idx == 0 ? -1 : 1;
        }
        else if( bparams.boostType == Boost::REAL )
        {
            double p = (node->value+1)*0.5;
            node->value = 0.5*log_ratio(p);
        }
    }

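    // Trains the requested number of weak trees sequentially, updating the
    // sample weights (and possibly shrinking the working set) after each tree
    // is added.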
    bool train( const Ptr<TrainData>& trainData, int flags )
    {
        startTraining(trainData, flags);
        int treeidx, ntrees = bparams.weakCount >= 0 ? bparams.weakCount : 10000;
        vector<int> sidx = w->sidx;

        for( treeidx = 0; treeidx < ntrees; treeidx++ )
        {
            int root = addTree( sidx );
            if( root < 0 )
                return false;
            updateWeightsAndTrim( treeidx, sidx );
        }
        endTraining();
        return true;
    }

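    // Runs the newly added tree over the training set, updates the sample
    // weights according to the boosting type, renormalizes them, and finally
    // trims the working set 'sidx' down to the samples that carry the top
    // weightTrimRate fraction of the total weight.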
    void updateWeightsAndTrim( int treeidx, vector<int>& sidx )
    {
        int i, n = (int)w->sidx.size();
        int nvars = (int)varIdx.size();
        double sumw = 0., C = 1.;
        cv::AutoBuffer<double> buf(n + nvars);
        double* result = buf;
        float* sbuf = (float*)(result + n);
        Mat sample(1, nvars, CV_32F, sbuf);
        int predictFlags = bparams.boostType == Boost::DISCRETE ? (PREDICT_MAX_VOTE | RAW_OUTPUT) : PREDICT_SUM;
        predictFlags |= COMPRESSED_INPUT;

        for( i = 0; i < n; i++ )
        {
            w->data->getSample(varIdx, w->sidx[i], sbuf );
            result[i] = predictTrees(Range(treeidx, treeidx+1), sample, predictFlags);
        }

        // now update weights and other parameters for each type of boosting
        if( bparams.boostType == Boost::DISCRETE )
        {
            // Discrete AdaBoost:
            //   weak_eval[i] (=f(x_i)) is in {-1,1}
            //   err = sum(w_i*(f(x_i) != y_i))/sum(w_i)
            //   C = log((1-err)/err)
            //   w_i *= exp(C*(f(x_i) != y_i))
            double err = 0.;

            for( i = 0; i < n; i++ )
            {
                int si = w->sidx[i];
                double wval = w->sample_weights[si];
                sumw += wval;
                err += wval*(result[i] != w->cat_responses[si]);
            }

            if( sumw != 0 )
                err /= sumw;
            C = -log_ratio( err );
            double scale = std::exp(C);

            sumw = 0;
            for( i = 0; i < n; i++ )
            {
                int si = w->sidx[i];
                double wval = w->sample_weights[si];
                if( result[i] != w->cat_responses[si] )
                    wval *= scale;
                sumw += wval;
                w->sample_weights[si] = wval;
            }

            scaleTree(roots[treeidx], C);
        }
        else if( bparams.boostType == Boost::REAL || bparams.boostType == Boost::GENTLE )
        {
            // Real AdaBoost:
            //   weak_eval[i] = f(x_i) = 0.5*log(p(x_i)/(1-p(x_i))), p(x_i)=P(y=1|x_i)
            //   w_i *= exp(-y_i*f(x_i))

            // Gentle AdaBoost:
            //   weak_eval[i] = f(x_i) in [-1,1]
            //   w_i *= exp(-y_i*f(x_i))
            for( i = 0; i < n; i++ )
            {
                int si = w->sidx[i];
                CV_Assert( std::abs(w->ord_responses[si]) == 1 );
                double wval = w->sample_weights[si]*std::exp(-result[i]*w->ord_responses[si]);
                sumw += wval;
                w->sample_weights[si] = wval;
            }
        }
        else if( bparams.boostType == Boost::LOGIT )
        {
            // LogitBoost:
            //   weak_eval[i] = f(x_i) in [-z_max,z_max]
            //   sumResult[i] = F(x_i)
            //   F(x_i) += 0.5*f(x_i)
            //   p(x_i) = exp(F(x_i))/(exp(F(x_i)) + exp(-F(x_i))) = 1/(1+exp(-2*F(x_i)))
            //   reuse weak_eval: weak_eval[i] <- p(x_i)
            //   w_i = p(x_i)*(1 - p(x_i))
            //   z_i = ((y_i+1)/2 - p(x_i))/(p(x_i)*(1 - p(x_i)))
            //   store z_i in w->ord_responses as the new target responses
            const double lb_weight_thresh = FLT_EPSILON;
            const double lb_z_max = 10.;

            for( i = 0; i < n; i++ )
            {
                int si = w->sidx[i];
                sumResult[i] += 0.5*result[i];
                double p = 1./(1 + std::exp(-2*sumResult[i]));
                double wval = std::max( p*(1 - p), lb_weight_thresh ), z;
                w->sample_weights[si] = wval;
                sumw += wval;
                if( w->ord_responses[si] > 0 )
                {
                    z = 1./p;
                    w->ord_responses[si] = std::min(z, lb_z_max);
                }
                else
                {
                    z = 1./(1-p);
                    w->ord_responses[si] = -std::min(z, lb_z_max);
                }
            }
        }
        else
            CV_Error(CV_StsNotImplemented, "Unknown boosting type");

        /*if( bparams.boostType != Boost::LOGIT )
        {
            double err = 0;
            for( i = 0; i < n; i++ )
            {
                sumResult[i] += result[i]*C;
                if( bparams.boostType != Boost::DISCRETE )
                    err += sumResult[i]*w->ord_responses[w->sidx[i]] < 0;
                else
                    err += sumResult[i]*w->cat_responses[w->sidx[i]] < 0;
            }
            printf("%d trees. C=%.2f, training error=%.1f%%, working set size=%d (out of %d)\n", (int)roots.size(), C, err*100./n, (int)sidx.size(), n);
        }*/

        // renormalize weights
        if( sumw > FLT_EPSILON )
            normalizeWeights();

        if( bparams.weightTrimRate <= 0. || bparams.weightTrimRate >= 1. )
            return;

        for( i = 0; i < n; i++ )
            result[i] = w->sample_weights[w->sidx[i]];
        std::sort(result, result + n);

        // since weight trimming occurs immediately after the weights have been
        // updated and renormalized, we may assume that the weight sum is 1
        sumw = 1. - bparams.weightTrimRate;

        for( i = 0; i < n; i++ )
        {
            double wval = result[i];
            if( sumw <= 0 )
                break;
            sumw -= wval;
        }

        double threshold = i < n ? result[i] : DBL_MAX;
        sidx.clear();

        for( i = 0; i < n; i++ )
        {
            int si = w->sidx[i];
            if( w->sample_weights[si] >= threshold )
                sidx.push_back(si);
        }
    }

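    // Sums the raw responses of the trees in 'range'; unless the caller asked
    // for PREDICT_SUM, the sign of the sum is then converted to a class index
    // (or to the original class label when RAW_OUTPUT is not set).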
    float predictTrees( const Range& range, const Mat& sample, int flags0 ) const
    {
        int flags = (flags0 & ~PREDICT_MASK) | PREDICT_SUM;
        float val = DTreesImpl::predictTrees(range, sample, flags);
        if( flags != flags0 )
        {
            int ival = (int)(val > 0);
            if( !(flags0 & RAW_OUTPUT) )
                ival = classLabels[ival];
            val = (float)ival;
        }
        return val;
    }

    void writeTrainingParams( FileStorage& fs ) const
    {
        fs << "boosting_type" <<
        (bparams.boostType == Boost::DISCRETE ? "DiscreteAdaboost" :
        bparams.boostType == Boost::REAL ? "RealAdaboost" :
        bparams.boostType == Boost::LOGIT ? "LogitBoost" :
        bparams.boostType == Boost::GENTLE ? "GentleAdaboost" : "Unknown");

        DTreesImpl::writeTrainingParams(fs);
        fs << "weight_trimming_rate" << bparams.weightTrimRate;
    }

    void write( FileStorage& fs ) const
    {
        if( roots.empty() )
            CV_Error( CV_StsBadArg, "Boost model has not been trained" );

        writeParams(fs);

        int k, ntrees = (int)roots.size();

        fs << "ntrees" << ntrees
        << "trees" << "[";

        for( k = 0; k < ntrees; k++ )
        {
            fs << "{";
            writeTree(fs, roots[k]);
            fs << "}";
        }

        fs << "]";
    }

    void readParams( const FileNode& fn )
    {
        DTreesImpl::readParams(fn);

        FileNode tparams_node = fn["training_params"];
        // check for old layout
        String bts = (String)(fn["boosting_type"].empty() ?
                         tparams_node["boosting_type"] : fn["boosting_type"]);
        bparams.boostType = (bts == "DiscreteAdaboost" ? Boost::DISCRETE :
                             bts == "RealAdaboost" ? Boost::REAL :
                             bts == "LogitBoost" ? Boost::LOGIT :
                             bts == "GentleAdaboost" ? Boost::GENTLE : -1);
        _isClassifier = bparams.boostType == Boost::DISCRETE;
        // check for old layout
        bparams.weightTrimRate = (double)(fn["weight_trimming_rate"].empty() ?
                                    tparams_node["weight_trimming_rate"] : fn["weight_trimming_rate"]);
    }

    void read( const FileNode& fn )
    {
        clear();

        int ntrees = (int)fn["ntrees"];
        readParams(fn);

        FileNode trees_node = fn["trees"];
        FileNodeIterator it = trees_node.begin();
        CV_Assert( ntrees == (int)trees_node.size() );

        for( int treeidx = 0; treeidx < ntrees; treeidx++, ++it )
        {
            FileNode nfn = (*it)["nodes"];
            readTree(nfn);
        }
    }

    BoostTreeParams bparams;
    vector<double> sumResult;
};


class BoostImpl : public Boost
{
public:
    BoostImpl() {}
    virtual ~BoostImpl() {}

    CV_IMPL_PROPERTY(int, BoostType, impl.bparams.boostType)
    CV_IMPL_PROPERTY(int, WeakCount, impl.bparams.weakCount)
    CV_IMPL_PROPERTY(double, WeightTrimRate, impl.bparams.weightTrimRate)

    CV_WRAP_SAME_PROPERTY(int, MaxCategories, impl.params)
    CV_WRAP_SAME_PROPERTY(int, MaxDepth, impl.params)
    CV_WRAP_SAME_PROPERTY(int, MinSampleCount, impl.params)
    CV_WRAP_SAME_PROPERTY(int, CVFolds, impl.params)
    CV_WRAP_SAME_PROPERTY(bool, UseSurrogates, impl.params)
    CV_WRAP_SAME_PROPERTY(bool, Use1SERule, impl.params)
    CV_WRAP_SAME_PROPERTY(bool, TruncatePrunedTree, impl.params)
    CV_WRAP_SAME_PROPERTY(float, RegressionAccuracy, impl.params)
    CV_WRAP_SAME_PROPERTY_S(cv::Mat, Priors, impl.params)

    String getDefaultName() const { return "opencv_ml_boost"; }

    bool train( const Ptr<TrainData>& trainData, int flags )
    {
        return impl.train(trainData, flags);
    }

    float predict( InputArray samples, OutputArray results, int flags ) const
    {
        return impl.predict(samples, results, flags);
    }

    void write( FileStorage& fs ) const
    {
        impl.write(fs);
    }

    void read( const FileNode& fn )
    {
        impl.read(fn);
    }

    int getVarCount() const { return impl.getVarCount(); }

    bool isTrained() const { return impl.isTrained(); }
    bool isClassifier() const { return impl.isClassifier(); }

    const vector<int>& getRoots() const { return impl.getRoots(); }
    const vector<Node>& getNodes() const { return impl.getNodes(); }
    const vector<Split>& getSplits() const { return impl.getSplits(); }
    const vector<int>& getSubsets() const { return impl.getSubsets(); }

    DTreesImplForBoost impl;
};


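// A minimal usage sketch (the 'samples', 'responses' and 'sampleRow' matrices
// are hypothetical, not part of this file):
//
//     Ptr<Boost> boost = Boost::create();
//     boost->setBoostType(Boost::REAL);
//     boost->setWeakCount(100);
//     boost->setWeightTrimRate(0.95);
//     boost->train(TrainData::create(samples, ROW_SAMPLE, responses));
//     float prediction = boost->predict(sampleRow);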
Ptr<Boost> Boost::create()
{
    return makePtr<BoostImpl>();
}

}}

/* End of file. */
    512