Home | History | Annotate | Download | only in src

Lines Matching defs:node

271     // allocate root node and the buffer for the whole training data
564 // make a copy of the root node
804 CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap );
806 node->sample_count = count;
807 node->depth = parent ? parent->depth + 1 : 0;
808 node->parent = parent;
809 node->left = node->right = 0;
810 node->split = 0;
811 node->value = 0;
812 node->class_idx = 0;
813 node->maxlr = 0.;
815 node->buf_idx = storage_idx;
816 node->offset = offset;
818 node->num_valid = (int*)cvSetNew( nv_heap );
820 node->num_valid = 0;
821 node->alpha = node->node_risk = node->tree_risk = node->tree_error = 0.;
822 node->complexity = 0;
827 node->Tn = INT_MAX;
828 node->cv_Tn = (int*)cvSetNew( cv_heap );
829 node->cv_node_risk = (double*)cvAlignPtr(node->cv_Tn + cv_n, sizeof(double));
830 node->cv_node_error = node->cv_node_risk + cv_n;
834 node->Tn = 0;
835 node->cv_Tn = 0;
836 node->cv_node_risk = 0;
837 node->cv_node_error = 0;
840 return node;
875 void CvDTreeTrainData::free_node( CvDTreeNode* node )
877 CvDTreeSplit* split = node->split;
878 free_node_data( node );
885 node->split = 0;
886 cvSetRemoveByPtr( node_heap, node );
890 void CvDTreeTrainData::free_node_data( CvDTreeNode* node )
892 if( node->num_valid )
894 cvSetRemoveByPtr( nv_heap, node->num_valid );
895 node->num_valid = 0;
1064 void CvDTreeTrainData::read_params( CvFileStorage* fs, CvFileNode* node )
1074 is_classifier = (cvReadIntByName( fs, node, "is_classifier" ) != 0);
1075 var_all = cvReadIntByName( fs, node, "var_all" );
1076 var_count = cvReadIntByName( fs, node, "var_count", var_all );
1077 cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );
1078 ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );
1080 tparams_node = cvGetFileNodeByName( fs, node, "training_params" );
1116 CV_CALL( var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));
1136 vartype_node = cvGetFileNodeByName( fs, node, "var_type" );
1167 CV_CALL( cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));
1168 CV_CALL( cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));
1334 void CvDTree::try_split_node( CvDTreeNode* node )
1337 int i, n = node->sample_count, vi;
1341 calc_node_value( node );
1343 if( node->sample_count <= data->params.min_sample_count ||
1344 node->depth >= data->params.max_depth )
1349 // check if we have a "pure" node,
1360 if( sqrt(node->node_risk)/n < data->params.regression_accuracy )
1366 best_split = find_best_split(node);
1368 node->split = best_split;
1373 data->free_node_data(node);
1377 quality_scale = calc_node_dir( node );
1392 split = find_surrogate_split_cat( node, vi );
1394 split = find_surrogate_split_ord( node, vi );
1399 CvDTreeSplit* prev_split = node->split;
1411 split_node_data( node );
1412 try_split_node( node->left );
1413 try_split_node( node->right );
1424 // besides, the function compute node->maxlr,
1426 // for a surrogate split. Surrogate splits with quality less than node->maxlr
1428 double CvDTree::calc_node_dir( CvDTreeNode* node )
1431 int i, n = node->sample_count, vi = node->split->var_idx;
1434 assert( !node->split->inversed );
1438 const int* labels = data->get_cat_var_data(node,vi);
1439 const int* subset = node->split->subset;
1458 const int* responses = data->get_class_labels(node);
1477 const CvPair32s32f* sorted = data->get_ord_var_data(node,vi);
1478 int split_point = node->split->ord.split_point;
1479 int n1 = node->get_num_valid(vi);
1497 const int* responses = data->get_class_labels(node);
1522 node->maxlr = MAX( L, R );
1523 return node->split->quality/(L + R);
1527 CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )
1535 if( node->get_num_valid(vi) <= 1 )
1541 split = find_split_cat_class( node, vi );
1543 split = find_split_ord_class( node, vi );
1548 split = find_split_cat_reg( node, vi );
1550 split = find_split_ord_reg( node, vi );
1566 CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi )
1569 const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
1570 const int* responses = data->get_class_labels(node);
1571 int n = node->sample_count;
1572 int n1 = node->get_num_valid(vi);
1756 CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi )
1759 const int* labels = data->get_cat_var_data(node, vi);
1760 const int* responses = data->get_class_labels(node);
1762 int n = node->sample_count;
1927 CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi )
1930 const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
1931 const float* responses = data->get_ord_responses(node);
1932 int n = node->sample_count;
1933 int n1 = node->get_num_valid(vi);
1935 double best_val = 0, lsum = 0, rsum = node->value*n;
1967 CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi )
1970 const int* labels = data->get_cat_var_data(node, vi);
1971 const float* responses = data->get_ord_responses(node);
1973 int n = node->sample_count;
2048 CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )
2051 const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
2053 int n1 = node->get_num_valid(vi);
2064 int worst_val = cvFloor(node->maxlr), _best_val = worst_val;
2108 double worst_val = node->maxlr;
2111 const int* responses = data->get_class_labels(node);
2156 return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
2162 CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )
2164 const int* labels = data->get_cat_var_data(node, vi);
2166 int n = node->sample_count;
2210 const int* responses = data->get_class_labels(node);
2247 if( split->quality <= node->maxlr || l_win == 0 || l_win == mi )
2254 void CvDTree::calc_node_value( CvDTreeNode* node )
2256 int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;
2257 const int* cv_labels = data->get_labels(node);
2262 // * node value is the label of the class that has the largest weight in the node.
2263 // * node risk is the weighted number of misclassified samples,
2271 const int* responses = data->get_class_labels(node);
2303 if( data->have_priors && node->parent == 0 )
2329 node->class_idx = max_k;
2330 node->value = data->cat_map->data.i[
2332 node->node_risk = total_weight - max_val;
2354 node->cv_Tn[j] = INT_MAX;
2355 node->cv_node_risk[j] = sum - max_val;
2356 node->cv_node_error[j] = sum_k - max_val_k;
2362 // * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
2363 // n is the number of samples in the node.
2364 // * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
2370 // where node_value_j is the node value calculated
2375 const float* values = data->get_ord_responses(node);
2419 node->node_risk = sum2 - (sum/n)*sum;
2420 node->value = sum/n;
2428 node->cv_node_risk[j] = s2i - r*r*ci;
2429 node->cv_node_error[j] = s2 - 2*r*s + c*r*r;
2430 node->cv_Tn[j] = INT_MAX;
2436 void CvDTree::complete_node_dir( CvDTreeNode* node )
2438 int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;
2439 int nz = n - node->get_num_valid(node->split->var_idx);
2445 CvDTreeSplit* split = node->split->next;
2453 const int* labels = data->get_cat_var_data(node, vi);
2470 const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
2472 int n1 = node->get_num_valid(vi);
2516 void CvDTree::split_node_data( CvDTreeNode* node )
2518 int vi, i, n = node->sample_count, nl, nr;
2522 int new_buf_idx = data->get_child_buf_idx( node );
2531 complete_node_dir(node);
2542 node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
2543 node->right = right = data->new_node( node, nr, new_buf_idx, node->offset +
2546 split_input_data = node->depth + 1 < data->params.max_depth &&
2547 (node->left->sample_count > data->params.min_sample_count ||
2548 node->right->sample_count > data->params.min_sample_count);
2554 int n1 = node->get_num_valid(vi);
2561 src = data->get_ord_var_data(node, vi);
2601 int n1 = node->get_num_valid(vi), nr1 = 0;
2608 src = data->get_cat_var_data(node, vi);
2632 // deallocate the parent node data that is not needed anymore
2633 data->free_node_data(node);
2739 CvDTreeNode* node = root;
2747 int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
2748 if( t <= T || !node->left )
2750 node->complexity = 1;
2751 node->tree_risk = node->node_risk;
2752 node->tree_error = 0.;
2755 node->tree_risk = node->cv_node_risk[fold];
2756 node->tree_error = node->cv_node_error[fold];
2760 node = node->left;
2763 for( parent = node->parent; parent && parent->right == node;
2764 node = parent, parent = parent->parent )
2766 parent->complexity += node->complexity;
2767 parent->tree_risk += node->tree_risk;
2768 parent->tree_error += node->tree_error;
2778 parent->complexity = node->complexity;
2779 parent->tree_risk = node->tree_risk;
2780 parent->tree_error = node->tree_error;
2781 node = parent->right;
2790 CvDTreeNode* node = root;
2791 if( !node->left )
2799 int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
2800 if( t <= T || !node->left )
2802 if( node->alpha <= min_alpha + FLT_EPSILON )
2805 node->cv_Tn[fold] = T;
2807 node->Tn = T;
2808 if( node == root )
2812 node = node->left;
2815 for( parent = node->parent; parent && parent->right == node;
2816 node = parent, parent = parent->parent )
2822 node = parent->right;
2831 CvDTreeNode* node = root;
2838 // do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn )
2840 node->cv_Tn = 0;
2841 node->cv_node_error = node->cv_node_risk = 0;
2842 if( !node->left )
2844 node = node->left;
2847 for( parent = node->parent; parent && parent->right == node;
2848 node = parent, parent = parent->parent )
2861 node = parent->right;
2894 CvDTreeNode* node = root;
2900 if( !node )
2937 while( node->Tn > pruned_tree_idx && node->left )
2939 CvDTreeSplit* split = node->split;
2994 double diff = node->right->sample_count - node->left->sample_count;
2997 node = dir < 0 ? node->left : node->right;
3000 result = node;
3012 CvDTreeNode* node = root;
3014 if( !node )
3023 for( ;; node = node->left )
3025 CvDTreeSplit* split = node->split;
3027 if( !node->left || node->Tn <= pruned_tree_idx )
3034 for( parent = node->parent; parent && parent->right == node;
3035 node = parent, parent = parent->parent )
3041 node = parent->right;
3088 void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node )
3094 cvWriteInt( fs, "depth", node->depth );
3095 cvWriteInt( fs, "sample_count", node->sample_count );
3096 cvWriteReal( fs, "value", node->value );
3099 cvWriteInt( fs, "norm_class_idx", node->class_idx );
3101 cvWriteInt( fs, "Tn", node->Tn );
3102 cvWriteInt( fs, "complexity", node->complexity );
3103 cvWriteReal( fs, "alpha", node->alpha );
3104 cvWriteReal( fs, "node_risk", node->node_risk );
3105 cvWriteReal( fs, "tree_risk", node->tree_risk );
3106 cvWriteReal( fs, "tree_error", node->tree_error );
3108 if( node->left )
3112 for( split = node->split; split != 0; split = split->next )
3128 CvDTreeNode* node = root;
3136 write_node( fs, node );
3137 if( !node->left )
3139 node = node->left;
3142 for( parent = node->parent; parent && parent->right == node;
3143 node = parent, parent = parent->parent )
3149 node = parent->right;
3282 CvDTreeNode* node = 0;
3294 CV_CALL( node = data->new_node( parent, 0, 0, 0 ));
3296 if( depth != node->depth )
3297 CV_ERROR( CV_StsParseError, "incorrect node depth" );
3299 node->sample_count = cvReadIntByName( fs, fnode, "sample_count" );
3300 node->value = cvReadRealByName( fs, fnode, "value" );
3302 node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" );
3304 node->Tn = cvReadIntByName( fs, fnode, "Tn" );
3305 node->complexity = cvReadIntByName( fs, fnode, "complexity" );
3306 node->alpha = cvReadRealByName( fs, fnode, "alpha" );
3307 node->node_risk = cvReadRealByName( fs, fnode, "node_risk" );
3308 node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" );
3309 node->tree_error = cvReadRealByName( fs, fnode, "tree_error" );
3326 node->split = last_split = split;
3336 return node;
3356 CvDTreeNode* node;
3358 CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 ));
3360 parent->left = node;
3362 parent->right = node;
3363 if( node->split )
3364 parent = node;
3391 void CvDTree::read( CvFileStorage* fs, CvFileNode* node, CvDTreeTrainData* _data )
3402 tree_nodes = cvGetFileNodeByName( fs, node, "nodes" );
3406 pruned_tree_idx = cvReadIntByName( fs, node, "best_tree_idx", -1 );