1 #ifndef _OPENCV_BOOST_H_ 2 #define _OPENCV_BOOST_H_ 3 4 #include "traincascade_features.h" 5 #include "old_ml.hpp" 6 7 struct CvCascadeBoostParams : CvBoostParams 8 { 9 float minHitRate; 10 float maxFalseAlarm; 11 12 CvCascadeBoostParams(); 13 CvCascadeBoostParams( int _boostType, float _minHitRate, float _maxFalseAlarm, 14 double _weightTrimRate, int _maxDepth, int _maxWeakCount ); 15 virtual ~CvCascadeBoostParams() {} 16 void write( cv::FileStorage &fs ) const; 17 bool read( const cv::FileNode &node ); 18 virtual void printDefaults() const; 19 virtual void printAttrs() const; 20 virtual bool scanAttr( const std::string prmName, const std::string val); 21 }; 22 23 struct CvCascadeBoostTrainData : CvDTreeTrainData 24 { 25 CvCascadeBoostTrainData( const CvFeatureEvaluator* _featureEvaluator, 26 const CvDTreeParams& _params ); 27 CvCascadeBoostTrainData( const CvFeatureEvaluator* _featureEvaluator, 28 int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize, 29 const CvDTreeParams& _params = CvDTreeParams() ); 30 virtual void setData( const CvFeatureEvaluator* _featureEvaluator, 31 int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize, 32 const CvDTreeParams& _params=CvDTreeParams() ); 33 void precalculate(); 34 35 virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx ); 36 37 virtual const int* get_class_labels( CvDTreeNode* n, int* labelsBuf ); 38 virtual const int* get_cv_labels( CvDTreeNode* n, int* labelsBuf); 39 virtual const int* get_sample_indices( CvDTreeNode* n, int* indicesBuf ); 40 41 virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ordValuesBuf, int* sortedIndicesBuf, 42 const float** ordValues, const int** sortedIndices, int* sampleIndicesBuf ); 43 virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* catValuesBuf ); 44 virtual float getVarValue( int vi, int si ); 45 virtual void free_train_data(); 46 47 const CvFeatureEvaluator* featureEvaluator; 48 cv::Mat valCache; // precalculated feature values (CV_32FC1) 49 CvMat _resp; // for casting 50 int numPrecalcVal, numPrecalcIdx; 51 }; 52 53 class CvCascadeBoostTree : public CvBoostTree 54 { 55 public: 56 virtual CvDTreeNode* predict( int sampleIdx ) const; 57 void write( cv::FileStorage &fs, const cv::Mat& featureMap ); 58 void read( const cv::FileNode &node, CvBoost* _ensemble, CvDTreeTrainData* _data ); 59 void markFeaturesInMap( cv::Mat& featureMap ); 60 protected: 61 virtual void split_node_data( CvDTreeNode* n ); 62 }; 63 64 class CvCascadeBoost : public CvBoost 65 { 66 public: 67 virtual bool train( const CvFeatureEvaluator* _featureEvaluator, 68 int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize, 69 const CvCascadeBoostParams& _params=CvCascadeBoostParams() ); 70 virtual float predict( int sampleIdx, bool returnSum = false ) const; 71 72 float getThreshold() const { return threshold; } 73 void write( cv::FileStorage &fs, const cv::Mat& featureMap ) const; 74 bool read( const cv::FileNode &node, const CvFeatureEvaluator* _featureEvaluator, 75 const CvCascadeBoostParams& _params ); 76 void markUsedFeaturesInMap( cv::Mat& featureMap ); 77 protected: 78 virtual bool set_params( const CvBoostParams& _params ); 79 virtual void update_weights( CvBoostTree* tree ); 80 virtual bool isErrDesired(); 81 82 float threshold; 83 float minHitRate, maxFalseAlarm; 84 }; 85 86 #endif 87