Home | History | Annotate | Download | only in ml
      1 /*M///////////////////////////////////////////////////////////////////////////////////////
      2 //
      3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
      4 //
      5 //  By downloading, copying, installing or using the software you agree to this license.
      6 //  If you do not agree to this license, do not download, install,
      7 //  copy or use the software.
      8 //
      9 //
     10 //                        Intel License Agreement
     11 //
     12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
     13 // Third party copyrights are property of their respective owners.
     14 //
     15 // Redistribution and use in source and binary forms, with or without modification,
     16 // are permitted provided that the following conditions are met:
     17 //
     18 //   * Redistribution's of source code must retain the above copyright notice,
     19 //     this list of conditions and the following disclaimer.
     20 //
     21 //   * Redistribution's in binary form must reproduce the above copyright notice,
     22 //     this list of conditions and the following disclaimer in the documentation
     23 //     and/or other materials provided with the distribution.
     24 //
     25 //   * The name of Intel Corporation may not be used to endorse or promote products
     26 //     derived from this software without specific prior written permission.
     27 //
     28 // This software is provided by the copyright holders and contributors "as is" and
     29 // any express or implied warranties, including, but not limited to, the implied
     30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
     31 // In no event shall the Intel Corporation or contributors be liable for any direct,
     32 // indirect, incidental, special, exemplary, or consequential damages
     33 // (including, but not limited to, procurement of substitute goods or services;
     34 // loss of use, data, or profits; or business interruption) however caused
     35 // and on any theory of liability, whether in contract, strict liability,
     36 // or tort (including negligence or otherwise) arising in any way out of
     37 // the use of this software, even if advised of the possibility of such damage.
     38 //
     39 //M*/
     40 
     41 #ifndef __OPENCV_ML_PRECOMP_HPP__
     42 #define __OPENCV_ML_PRECOMP_HPP__
     43 
     44 #include "opencv2/core.hpp"
     45 #include "opencv2/ml.hpp"
     46 #include "opencv2/core/core_c.h"
     47 #include "opencv2/core/utility.hpp"
     48 
     49 #include "opencv2/core/private.hpp"
     50 
     51 #include <assert.h>
     52 #include <float.h>
     53 #include <limits.h>
     54 #include <math.h>
     55 #include <stdlib.h>
     56 #include <stdio.h>
     57 #include <string.h>
     58 #include <time.h>
     59 #include <vector>
     60 
     61 /****************************************************************************************\
     62  *                               Main struct definitions                                  *
     63  \****************************************************************************************/
     64 
     65 /* log(2*PI) */
     66 #define CV_LOG2PI (1.8378770664093454835606594728112)
     67 
     68 namespace cv
     69 {
     70 namespace ml
     71 {
     72     using std::vector;
     73 
     74     #define CV_DTREE_CAT_DIR(idx,subset) \
     75         (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1)
     76 
     77     template<typename _Tp> struct cmp_lt_idx
     78     {
     79         cmp_lt_idx(const _Tp* _arr) : arr(_arr) {}
     80         bool operator ()(int a, int b) const { return arr[a] < arr[b]; }
     81         const _Tp* arr;
     82     };
     83 
     84     template<typename _Tp> struct cmp_lt_ptr
     85     {
     86         cmp_lt_ptr() {}
     87         bool operator ()(const _Tp* a, const _Tp* b) const { return *a < *b; }
     88     };
     89 
     90     static inline void setRangeVector(std::vector<int>& vec, int n)
     91     {
     92         vec.resize(n);
     93         for( int i = 0; i < n; i++ )
     94             vec[i] = i;
     95     }
     96 
     97     static inline void writeTermCrit(FileStorage& fs, const TermCriteria& termCrit)
     98     {
     99         if( (termCrit.type & TermCriteria::EPS) != 0 )
    100             fs << "epsilon" << termCrit.epsilon;
    101         if( (termCrit.type & TermCriteria::COUNT) != 0 )
    102             fs << "iterations" << termCrit.maxCount;
    103     }
    104 
    105     static inline TermCriteria readTermCrit(const FileNode& fn)
    106     {
    107         TermCriteria termCrit;
    108         double epsilon = (double)fn["epsilon"];
    109         if( epsilon > 0 )
    110         {
    111             termCrit.type |= TermCriteria::EPS;
    112             termCrit.epsilon = epsilon;
    113         }
    114         int iters = (int)fn["iterations"];
    115         if( iters > 0 )
    116         {
    117             termCrit.type |= TermCriteria::COUNT;
    118             termCrit.maxCount = iters;
    119         }
    120         return termCrit;
    121     }
    122 
    123     struct TreeParams
    124     {
    125         TreeParams();
    126         TreeParams( int maxDepth, int minSampleCount,
    127                     double regressionAccuracy, bool useSurrogates,
    128                     int maxCategories, int CVFolds,
    129                     bool use1SERule, bool truncatePrunedTree,
    130                     const Mat& priors );
    131 
    132         inline void setMaxCategories(int val)
    133         {
    134             if( val < 2 )
    135                 CV_Error( CV_StsOutOfRange, "max_categories should be >= 2" );
    136             maxCategories = std::min(val, 15 );
    137         }
    138         inline void setMaxDepth(int val)
    139         {
    140             if( val < 0 )
    141                 CV_Error( CV_StsOutOfRange, "max_depth should be >= 0" );
    142             maxDepth = std::min( val, 25 );
    143         }
    144         inline void setMinSampleCount(int val)
    145         {
    146             minSampleCount = std::max(val, 1);
    147         }
    148         inline void setCVFolds(int val)
    149         {
    150             if( val < 0 )
    151                 CV_Error( CV_StsOutOfRange,
    152                           "params.CVFolds should be =0 (the tree is not pruned) "
    153                           "or n>0 (tree is pruned using n-fold cross-validation)" );
    154             if( val == 1 )
    155                 val = 0;
    156             CVFolds = val;
    157         }
    158         inline void setRegressionAccuracy(float val)
    159         {
    160             if( val < 0 )
    161                 CV_Error( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );
    162             regressionAccuracy = val;
    163         }
    164 
    165         inline int getMaxCategories() const { return maxCategories; }
    166         inline int getMaxDepth() const { return maxDepth; }
    167         inline int getMinSampleCount() const { return minSampleCount; }
    168         inline int getCVFolds() const { return CVFolds; }
    169         inline float getRegressionAccuracy() const { return regressionAccuracy; }
    170 
    171         CV_IMPL_PROPERTY(bool, UseSurrogates, useSurrogates)
    172         CV_IMPL_PROPERTY(bool, Use1SERule, use1SERule)
    173         CV_IMPL_PROPERTY(bool, TruncatePrunedTree, truncatePrunedTree)
    174         CV_IMPL_PROPERTY_S(cv::Mat, Priors, priors)
    175 
    176         public:
    177             bool  useSurrogates;
    178         bool  use1SERule;
    179         bool  truncatePrunedTree;
    180         Mat priors;
    181 
    182     protected:
    183         int   maxCategories;
    184         int   maxDepth;
    185         int   minSampleCount;
    186         int   CVFolds;
    187         float regressionAccuracy;
    188     };
    189 
    190     struct RTreeParams
    191     {
    192         RTreeParams();
    193         RTreeParams(bool calcVarImportance, int nactiveVars, TermCriteria termCrit );
    194         bool calcVarImportance;
    195         int nactiveVars;
    196         TermCriteria termCrit;
    197     };
    198 
    /// Extra training parameters for boosted trees.
    /// Only the constructors are declared here; their definitions live elsewhere.
    struct BoostTreeParams
    {
        BoostTreeParams();
        BoostTreeParams( int boostType,
                         int weakCount,
                         double weightTrimRate );

        int boostType;          // boosting variant selector (enum defined elsewhere — assumed)
        int weakCount;          // presumably the number of weak learners — confirm
        double weightTrimRate;  // presumably the weight-trimming threshold — confirm
    };
    207 
    /// Concrete implementation of the DTrees interface.
    ///
    /// It holds the trained model (roots / nodes / splits / subsets plus the
    /// per-variable metadata declared at the bottom), the tree parameters
    /// (`params`), and a Ptr<WorkData> `w` holding the W* working structures
    /// declared below.
    /// NOTE(review): `w` presumably only exists between startTraining() and
    /// endTraining() — confirm against the implementation file.
    class DTreesImpl : public DTrees
    {
    public:
        /// Node record used while growing a tree; stored in WorkData::wnodes.
        struct WNode
        {
            /// class_idx, sample_count, depth and complexity start at 0;
            /// parent, left, right, split and defaultDir start at -1
            /// (presumably "unset"); Tn starts at INT_MAX; all other
            /// double fields start at 0.
            WNode()
            {
                class_idx = sample_count = depth = complexity = 0;
                parent = left = right = split = defaultDir = -1;
                Tn = INT_MAX;
                value = maxlr = alpha = node_risk = tree_risk = tree_error = 0.;
            }

            int class_idx;
            double Tn;
            double value;

            int parent;
            int left;
            int right;
            int defaultDir;

            int split;

            int sample_count;
            int depth;
            double maxlr;

            // global pruning data
            int complexity;
            double alpha;
            double node_risk, tree_risk, tree_error;
        };

        /// Split record used while growing a tree; stored in WorkData::wsplits.
        struct WSplit
        {
            /// All fields start at zero/false except subsetOfs, which starts
            /// at -1 (presumably "no category subset") — confirm.
            WSplit()
            {
                varIdx = next = 0;
                inversed = false;
                quality = c = 0.f;
                subsetOfs = -1;
            }

            int varIdx;
            bool inversed;
            float quality;
            int next;
            float c;
            int subsetOfs;
        };

        /// Training-time scratch state: the TrainData being used, the working
        /// node/split/subset arrays, per-sample working arrays, and the cv_*
        /// vectors (presumably per-fold cross-validation data used by
        /// pruneCV — confirm).
        struct WorkData
        {
            WorkData(const Ptr<TrainData>& _data);

            Ptr<TrainData> data;
            vector<WNode> wnodes;
            vector<WSplit> wsplits;
            vector<int> wsubsets;
            vector<double> cv_Tn;
            vector<double> cv_node_risk;
            vector<double> cv_node_error;
            vector<int> cv_labels;
            vector<double> sample_weights;
            vector<int> cat_responses;
            vector<double> ord_responses;
            vector<int> sidx;
            int maxSubsetSize;
        };

        // Property accessors; the macro (defined elsewhere) presumably forwards
        // each get/set pair to the same-named accessor of `params`.
        CV_WRAP_SAME_PROPERTY(int, MaxCategories, params)
        CV_WRAP_SAME_PROPERTY(int, MaxDepth, params)
        CV_WRAP_SAME_PROPERTY(int, MinSampleCount, params)
        CV_WRAP_SAME_PROPERTY(int, CVFolds, params)
        CV_WRAP_SAME_PROPERTY(bool, UseSurrogates, params)
        CV_WRAP_SAME_PROPERTY(bool, Use1SERule, params)
        CV_WRAP_SAME_PROPERTY(bool, TruncatePrunedTree, params)
        CV_WRAP_SAME_PROPERTY(float, RegressionAccuracy, params)
        CV_WRAP_SAME_PROPERTY_S(cv::Mat, Priors, params)

        DTreesImpl();
        virtual ~DTreesImpl();
        virtual void clear();

        String getDefaultName() const { return "opencv_ml_dtree"; }
        // The model counts as trained once at least one tree root exists.
        bool isTrained() const { return !roots.empty(); }
        bool isClassifier() const { return _isClassifier; }
        // varType.size() - 1, or 0 when varType is empty; presumably the last
        // varType entry describes the response rather than an input — confirm.
        int getVarCount() const { return varType.empty() ? 0 : (int)(varType.size() - 1); }
        // Number of categories of variable vi: catOfs[vi][1] - catOfs[vi][0].
        int getCatCount(int vi) const { return catOfs[vi][1] - catOfs[vi][0]; }
        // ceil(getCatCount(vi) / 32): int words needed for a one-bit-per-category
        // subset (the layout read by CV_DTREE_CAT_DIR).
        int getSubsetSize(int vi) const { return (getCatCount(vi) + 31)/32; }

        // ---- training ----
        virtual void setDParams(const TreeParams& _params);
        virtual void startTraining( const Ptr<TrainData>& trainData, int flags );
        virtual void endTraining();
        virtual void initCompVarIdx();
        virtual bool train( const Ptr<TrainData>& trainData, int flags );

        virtual int addTree( const vector<int>& sidx );
        virtual int addNodeAndTrySplit( int parent, const vector<int>& sidx );
        virtual const vector<int>& getActiveVars();
        virtual int findBestSplit( const vector<int>& _sidx );
        virtual void calcValue( int nidx, const vector<int>& _sidx );

        // Split search, one variant per (ordered/categorical) x (classification/regression).
        virtual WSplit findSplitOrdClass( int vi, const vector<int>& _sidx, double initQuality );

        // simple k-means, slightly modified to take into account the "weight" (L1-norm) of each vector.
        virtual void clusterCategories( const double* vectors, int n, int m, double* csums, int k, int* labels );
        virtual WSplit findSplitCatClass( int vi, const vector<int>& _sidx, double initQuality, int* subset );

        virtual WSplit findSplitOrdReg( int vi, const vector<int>& _sidx, double initQuality );
        virtual WSplit findSplitCatReg( int vi, const vector<int>& _sidx, double initQuality, int* subset );

        virtual int calcDir( int splitidx, const vector<int>& _sidx, vector<int>& _sleft, vector<int>& _sright );
        virtual int pruneCV( int root );

        virtual double updateTreeRNC( int root, double T, int fold );
        virtual bool cutTree( int root, double T, int fold, double min_alpha );

        // ---- prediction ----
        virtual float predictTrees( const Range& range, const Mat& sample, int flags ) const;
        virtual float predict( InputArray inputs, OutputArray outputs, int flags ) const;

        // ---- serialization (FileStorage / FileNode) ----
        virtual void writeTrainingParams( FileStorage& fs ) const;
        virtual void writeParams( FileStorage& fs ) const;
        virtual void writeSplit( FileStorage& fs, int splitidx ) const;
        virtual void writeNode( FileStorage& fs, int nidx, int depth ) const;
        virtual void writeTree( FileStorage& fs, int root ) const;
        virtual void write( FileStorage& fs ) const;

        virtual void readParams( const FileNode& fn );
        virtual int readSplit( const FileNode& fn );
        virtual int readNode( const FileNode& fn );
        virtual int readTree( const FileNode& fn );
        virtual void read( const FileNode& fn );

        // Read-only views of the model arrays below.
        virtual const std::vector<int>& getRoots() const { return roots; }
        virtual const std::vector<Node>& getNodes() const { return nodes; }
        virtual const std::vector<Split>& getSplits() const { return splits; }
        virtual const std::vector<int>& getSubsets() const { return subsets; }

        TreeParams params;

        // ---- model state ----
        vector<int> varIdx;
        vector<int> compVarIdx;
        vector<uchar> varType;
        vector<Vec2i> catOfs;
        vector<int> catMap;
        vector<int> roots;
        vector<Node> nodes;
        vector<Split> splits;
        vector<int> subsets;
        vector<int> classLabels;
        vector<float> missingSubst;
        vector<int> varMapping;
        bool _isClassifier;

        // Training-time working data (see WorkData).
        Ptr<WorkData> w;
    };
    365 
    366     template <typename T>
    367     static inline void readVectorOrMat(const FileNode & node, std::vector<T> & v)
    368     {
    369         if (node.type() == FileNode::MAP)
    370         {
    371             Mat m;
    372             node >> m;
    373             m.copyTo(v);
    374         }
    375         else if (node.type() == FileNode::SEQ)
    376         {
    377             node >> v;
    378         }
    379     }
    380 
    381 }}
    382 
    383 #endif /* __OPENCV_ML_PRECOMP_HPP__ */
    384