Home | History | Annotate | Download | only in traincascade
      1 #ifndef _OPENCV_BOOST_H_
      2 #define _OPENCV_BOOST_H_
      3 
      4 #include "traincascade_features.h"
      5 #include "old_ml.hpp"
      6 
      7 struct CvCascadeBoostParams : CvBoostParams
      8 {
      9     float minHitRate;
     10     float maxFalseAlarm;
     11 
     12     CvCascadeBoostParams();
     13     CvCascadeBoostParams( int _boostType, float _minHitRate, float _maxFalseAlarm,
     14                           double _weightTrimRate, int _maxDepth, int _maxWeakCount );
     15     virtual ~CvCascadeBoostParams() {}
     16     void write( cv::FileStorage &fs ) const;
     17     bool read( const cv::FileNode &node );
     18     virtual void printDefaults() const;
     19     virtual void printAttrs() const;
     20     virtual bool scanAttr( const std::string prmName, const std::string val);
     21 };
     22 
     23 struct CvCascadeBoostTrainData : CvDTreeTrainData
     24 {
     25     CvCascadeBoostTrainData( const CvFeatureEvaluator* _featureEvaluator,
     26                              const CvDTreeParams& _params );
     27     CvCascadeBoostTrainData( const CvFeatureEvaluator* _featureEvaluator,
     28                              int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize,
     29                              const CvDTreeParams& _params = CvDTreeParams() );
     30     virtual void setData( const CvFeatureEvaluator* _featureEvaluator,
     31                           int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize,
     32                           const CvDTreeParams& _params=CvDTreeParams() );
     33     void precalculate();
     34 
     35     virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
     36 
     37     virtual const int* get_class_labels( CvDTreeNode* n, int* labelsBuf );
     38     virtual const int* get_cv_labels( CvDTreeNode* n, int* labelsBuf);
     39     virtual const int* get_sample_indices( CvDTreeNode* n, int* indicesBuf );
     40 
     41     virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ordValuesBuf, int* sortedIndicesBuf,
     42                                   const float** ordValues, const int** sortedIndices, int* sampleIndicesBuf );
     43     virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* catValuesBuf );
     44     virtual float getVarValue( int vi, int si );
     45     virtual void free_train_data();
     46 
     47     const CvFeatureEvaluator* featureEvaluator;
     48     cv::Mat valCache; // precalculated feature values (CV_32FC1)
     49     CvMat _resp; // for casting
     50     int numPrecalcVal, numPrecalcIdx;
     51 };
     52 
     53 class CvCascadeBoostTree : public CvBoostTree
     54 {
     55 public:
     56     virtual CvDTreeNode* predict( int sampleIdx ) const;
     57     void write( cv::FileStorage &fs, const cv::Mat& featureMap );
     58     void read( const cv::FileNode &node, CvBoost* _ensemble, CvDTreeTrainData* _data );
     59     void markFeaturesInMap( cv::Mat& featureMap );
     60 protected:
     61     virtual void split_node_data( CvDTreeNode* n );
     62 };
     63 
     64 class CvCascadeBoost : public CvBoost
     65 {
     66 public:
     67     virtual bool train( const CvFeatureEvaluator* _featureEvaluator,
     68                         int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize,
     69                         const CvCascadeBoostParams& _params=CvCascadeBoostParams() );
     70     virtual float predict( int sampleIdx, bool returnSum = false ) const;
     71 
     72     float getThreshold() const { return threshold; }
     73     void write( cv::FileStorage &fs, const cv::Mat& featureMap ) const;
     74     bool read( const cv::FileNode &node, const CvFeatureEvaluator* _featureEvaluator,
     75                const CvCascadeBoostParams& _params );
     76     void markUsedFeaturesInMap( cv::Mat& featureMap );
     77 protected:
     78     virtual bool set_params( const CvBoostParams& _params );
     79     virtual void update_weights( CvBoostTree* tree );
     80     virtual bool isErrDesired();
     81 
     82     float threshold;
     83     float minHitRate, maxFalseAlarm;
     84 };
     85 
     86 #endif
     87