Home | History | Annotate | Download | only in src
      1 /*M///////////////////////////////////////////////////////////////////////////////////////
      2 //
      3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
      4 //
      5 //  By downloading, copying, installing or using the software you agree to this license.
      6 //  If you do not agree to this license, do not download, install,
      7 //  copy or use the software.
      8 //
      9 //
     10 //                        Intel License Agreement
     11 //
     12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
     13 // Third party copyrights are property of their respective owners.
     14 //
     15 // Redistribution and use in source and binary forms, with or without modification,
     16 // are permitted provided that the following conditions are met:
     17 //
     18 //   * Redistribution's of source code must retain the above copyright notice,
     19 //     this list of conditions and the following disclaimer.
     20 //
     21 //   * Redistribution's in binary form must reproduce the above copyright notice,
     22 //     this list of conditions and the following disclaimer in the documentation
     23 //     and/or other materials provided with the distribution.
     24 //
     25 //   * The name of Intel Corporation may not be used to endorse or promote products
     26 //     derived from this software without specific prior written permission.
     27 //
     28 // This software is provided by the copyright holders and contributors "as is" and
     29 // any express or implied warranties, including, but not limited to, the implied
     30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
     31 // In no event shall the Intel Corporation or contributors be liable for any direct,
     32 // indirect, incidental, special, exemplary, or consequential damages
     33 // (including, but not limited to, procurement of substitute goods or services;
     34 // loss of use, data, or profits; or business interruption) however caused
     35 // and on any theory of liability, whether in contract, strict liability,
     36 // or tort (including negligence or otherwise) arising in any way out of
     37 // the use of this software, even if advised of the possibility of such damage.
     38 //
     39 //M*/
     40 
     41 #include "precomp.hpp"
     42 #include <ctype.h>
     43 #include <algorithm>
     44 #include <iterator>
     45 
     46 namespace cv { namespace ml {
     47 
     48 static const float MISSED_VAL = TrainData::missingValue();
     49 static const int VAR_MISSED = VAR_ORDERED;
     50 
     51 TrainData::~TrainData() {}
     52 
     53 Mat TrainData::getSubVector(const Mat& vec, const Mat& idx)
     54 {
     55     if( idx.empty() )
     56         return vec;
     57     int i, j, n = idx.checkVector(1, CV_32S);
     58     int type = vec.type();
     59     CV_Assert( type == CV_32S || type == CV_32F || type == CV_64F );
     60     int dims = 1, m;
     61 
     62     if( vec.cols == 1 || vec.rows == 1 )
     63     {
     64         dims = 1;
     65         m = vec.cols + vec.rows - 1;
     66     }
     67     else
     68     {
     69         dims = vec.cols;
     70         m = vec.rows;
     71     }
     72 
     73     Mat subvec;
     74 
     75     if( vec.cols == m )
     76         subvec.create(dims, n, type);
     77     else
     78         subvec.create(n, dims, type);
     79     if( type == CV_32S )
     80         for( i = 0; i < n; i++ )
     81         {
     82             int k = idx.at<int>(i);
     83             CV_Assert( 0 <= k && k < m );
     84             if( dims == 1 )
     85                 subvec.at<int>(i) = vec.at<int>(k);
     86             else
     87                 for( j = 0; j < dims; j++ )
     88                     subvec.at<int>(i, j) = vec.at<int>(k, j);
     89         }
     90     else if( type == CV_32F )
     91         for( i = 0; i < n; i++ )
     92         {
     93             int k = idx.at<int>(i);
     94             CV_Assert( 0 <= k && k < m );
     95             if( dims == 1 )
     96                 subvec.at<float>(i) = vec.at<float>(k);
     97             else
     98                 for( j = 0; j < dims; j++ )
     99                     subvec.at<float>(i, j) = vec.at<float>(k, j);
    100         }
    101     else
    102         for( i = 0; i < n; i++ )
    103         {
    104             int k = idx.at<int>(i);
    105             CV_Assert( 0 <= k && k < m );
    106             if( dims == 1 )
    107                 subvec.at<double>(i) = vec.at<double>(k);
    108             else
    109                 for( j = 0; j < dims; j++ )
    110                     subvec.at<double>(i, j) = vec.at<double>(k, j);
    111         }
    112     return subvec;
    113 }
    114 
    115 class TrainDataImpl : public TrainData
    116 {
    117 public:
    118     typedef std::map<String, int> MapType;
    119 
    120     TrainDataImpl()
    121     {
    122         file = 0;
    123         clear();
    124     }
    125 
    126     virtual ~TrainDataImpl() { closeFile(); }
    127 
    128     int getLayout() const { return layout; }
    129     int getNSamples() const
    130     {
    131         return !sampleIdx.empty() ? (int)sampleIdx.total() :
    132                layout == ROW_SAMPLE ? samples.rows : samples.cols;
    133     }
    134     int getNTrainSamples() const
    135     {
    136         return !trainSampleIdx.empty() ? (int)trainSampleIdx.total() : getNSamples();
    137     }
    138     int getNTestSamples() const
    139     {
    140         return !testSampleIdx.empty() ? (int)testSampleIdx.total() : 0;
    141     }
    142     int getNVars() const
    143     {
    144         return !varIdx.empty() ? (int)varIdx.total() : getNAllVars();
    145     }
    146     int getNAllVars() const
    147     {
    148         return layout == ROW_SAMPLE ? samples.cols : samples.rows;
    149     }
    150 
    151     Mat getSamples() const { return samples; }
    152     Mat getResponses() const { return responses; }
    153     Mat getMissing() const { return missing; }
    154     Mat getVarIdx() const { return varIdx; }
    155     Mat getVarType() const { return varType; }
    156     int getResponseType() const
    157     {
    158         return classLabels.empty() ? VAR_ORDERED : VAR_CATEGORICAL;
    159     }
    160     Mat getTrainSampleIdx() const { return !trainSampleIdx.empty() ? trainSampleIdx : sampleIdx; }
    161     Mat getTestSampleIdx() const { return testSampleIdx; }
    162     Mat getSampleWeights() const
    163     {
    164         return sampleWeights;
    165     }
    166     Mat getTrainSampleWeights() const
    167     {
    168         return getSubVector(sampleWeights, getTrainSampleIdx());
    169     }
    170     Mat getTestSampleWeights() const
    171     {
    172         Mat idx = getTestSampleIdx();
    173         return idx.empty() ? Mat() : getSubVector(sampleWeights, idx);
    174     }
    175     Mat getTrainResponses() const
    176     {
    177         return getSubVector(responses, getTrainSampleIdx());
    178     }
    179     Mat getTrainNormCatResponses() const
    180     {
    181         return getSubVector(normCatResponses, getTrainSampleIdx());
    182     }
    183     Mat getTestResponses() const
    184     {
    185         Mat idx = getTestSampleIdx();
    186         return idx.empty() ? Mat() : getSubVector(responses, idx);
    187     }
    188     Mat getTestNormCatResponses() const
    189     {
    190         Mat idx = getTestSampleIdx();
    191         return idx.empty() ? Mat() : getSubVector(normCatResponses, idx);
    192     }
    193     Mat getNormCatResponses() const { return normCatResponses; }
    194     Mat getClassLabels() const { return classLabels; }
    195     Mat getClassCounters() const { return classCounters; }
    196     int getCatCount(int vi) const
    197     {
    198         int n = (int)catOfs.total();
    199         CV_Assert( 0 <= vi && vi < n );
    200         Vec2i ofs = catOfs.at<Vec2i>(vi);
    201         return ofs[1] - ofs[0];
    202     }
    203 
    204     Mat getCatOfs() const { return catOfs; }
    205     Mat getCatMap() const { return catMap; }
    206 
    207     Mat getDefaultSubstValues() const { return missingSubst; }
    208 
    209     void closeFile() { if(file) fclose(file); file=0; }
    210     void clear()
    211     {
    212         closeFile();
    213         samples.release();
    214         missing.release();
    215         varType.release();
    216         responses.release();
    217         sampleIdx.release();
    218         trainSampleIdx.release();
    219         testSampleIdx.release();
    220         normCatResponses.release();
    221         classLabels.release();
    222         classCounters.release();
    223         catMap.release();
    224         catOfs.release();
    225         nameMap = MapType();
    226         layout = ROW_SAMPLE;
    227     }
    228 
    229     typedef std::map<int, int> CatMapHash;
    230 
    231     void setData(InputArray _samples, int _layout, InputArray _responses,
    232                  InputArray _varIdx, InputArray _sampleIdx, InputArray _sampleWeights,
    233                  InputArray _varType, InputArray _missing)
    234     {
    235         clear();
    236 
    237         CV_Assert(_layout == ROW_SAMPLE || _layout == COL_SAMPLE );
    238         samples = _samples.getMat();
    239         layout = _layout;
    240         responses = _responses.getMat();
    241         varIdx = _varIdx.getMat();
    242         sampleIdx = _sampleIdx.getMat();
    243         sampleWeights = _sampleWeights.getMat();
    244         varType = _varType.getMat();
    245         missing = _missing.getMat();
    246 
    247         int nsamples = layout == ROW_SAMPLE ? samples.rows : samples.cols;
    248         int ninputvars = layout == ROW_SAMPLE ? samples.cols : samples.rows;
    249         int i, noutputvars = 0;
    250 
    251         CV_Assert( samples.type() == CV_32F || samples.type() == CV_32S );
    252 
    253         if( !sampleIdx.empty() )
    254         {
    255             CV_Assert( (sampleIdx.checkVector(1, CV_32S, true) > 0 &&
    256                        checkRange(sampleIdx, true, 0, 0, nsamples-1)) ||
    257                        sampleIdx.checkVector(1, CV_8U, true) == nsamples );
    258             if( sampleIdx.type() == CV_8U )
    259                 sampleIdx = convertMaskToIdx(sampleIdx);
    260         }
    261 
    262         if( !sampleWeights.empty() )
    263         {
    264             CV_Assert( sampleWeights.checkVector(1, CV_32F, true) == nsamples );
    265         }
    266         else
    267         {
    268             sampleWeights = Mat::ones(nsamples, 1, CV_32F);
    269         }
    270 
    271         if( !varIdx.empty() )
    272         {
    273             CV_Assert( (varIdx.checkVector(1, CV_32S, true) > 0 &&
    274                        checkRange(varIdx, true, 0, 0, ninputvars)) ||
    275                        varIdx.checkVector(1, CV_8U, true) == ninputvars );
    276             if( varIdx.type() == CV_8U )
    277                 varIdx = convertMaskToIdx(varIdx);
    278             varIdx = varIdx.clone();
    279             std::sort(varIdx.ptr<int>(), varIdx.ptr<int>() + varIdx.total());
    280         }
    281 
    282         if( !responses.empty() )
    283         {
    284             CV_Assert( responses.type() == CV_32F || responses.type() == CV_32S );
    285             if( (responses.cols == 1 || responses.rows == 1) && (int)responses.total() == nsamples )
    286                 noutputvars = 1;
    287             else
    288             {
    289                 CV_Assert( (layout == ROW_SAMPLE && responses.rows == nsamples) ||
    290                            (layout == COL_SAMPLE && responses.cols == nsamples) );
    291                 noutputvars = layout == ROW_SAMPLE ? responses.cols : responses.rows;
    292             }
    293             if( !responses.isContinuous() || (layout == COL_SAMPLE && noutputvars > 1) )
    294             {
    295                 Mat temp;
    296                 transpose(responses, temp);
    297                 responses = temp;
    298             }
    299         }
    300 
    301         int nvars = ninputvars + noutputvars;
    302 
    303         if( !varType.empty() )
    304         {
    305             CV_Assert( varType.checkVector(1, CV_8U, true) == nvars &&
    306                        checkRange(varType, true, 0, VAR_ORDERED, VAR_CATEGORICAL+1) );
    307         }
    308         else
    309         {
    310             varType.create(1, nvars, CV_8U);
    311             varType = Scalar::all(VAR_ORDERED);
    312             if( noutputvars == 1 )
    313                 varType.at<uchar>(ninputvars) = (uchar)(responses.type() < CV_32F ? VAR_CATEGORICAL : VAR_ORDERED);
    314         }
    315 
    316         if( noutputvars > 1 )
    317         {
    318             for( i = 0; i < noutputvars; i++ )
    319                 CV_Assert( varType.at<uchar>(ninputvars + i) == VAR_ORDERED );
    320         }
    321 
    322         catOfs = Mat::zeros(1, nvars, CV_32SC2);
    323         missingSubst = Mat::zeros(1, nvars, CV_32F);
    324 
    325         vector<int> labels, counters, sortbuf, tempCatMap;
    326         vector<Vec2i> tempCatOfs;
    327         CatMapHash ofshash;
    328 
    329         AutoBuffer<uchar> buf(nsamples);
    330         Mat non_missing(layout == ROW_SAMPLE ? Size(1, nsamples) : Size(nsamples, 1), CV_8U, (uchar*)buf);
    331         bool haveMissing = !missing.empty();
    332         if( haveMissing )
    333         {
    334             CV_Assert( missing.size() == samples.size() && missing.type() == CV_8U );
    335         }
    336 
    337         // we iterate through all the variables. For each categorical variable we build a map
    338         // in order to convert input values of the variable into normalized values (0..catcount_vi-1)
    339         // often many categorical variables are similar, so we compress the map - try to re-use
    340         // maps for different variables if they are identical
    341         for( i = 0; i < ninputvars; i++ )
    342         {
    343             Mat values_i = layout == ROW_SAMPLE ? samples.col(i) : samples.row(i);
    344 
    345             if( varType.at<uchar>(i) == VAR_CATEGORICAL )
    346             {
    347                 preprocessCategorical(values_i, 0, labels, 0, sortbuf);
    348                 missingSubst.at<float>(i) = -1.f;
    349                 int j, m = (int)labels.size();
    350                 CV_Assert( m > 0 );
    351                 int a = labels.front(), b = labels.back();
    352                 const int* currmap = &labels[0];
    353                 int hashval = ((unsigned)a*127 + (unsigned)b)*127 + m;
    354                 CatMapHash::iterator it = ofshash.find(hashval);
    355                 if( it != ofshash.end() )
    356                 {
    357                     int vi = it->second;
    358                     Vec2i ofs0 = tempCatOfs[vi];
    359                     int m0 = ofs0[1] - ofs0[0];
    360                     const int* map0 = &tempCatMap[ofs0[0]];
    361                     if( m0 == m && map0[0] == a && map0[m0-1] == b )
    362                     {
    363                         for( j = 0; j < m; j++ )
    364                             if( map0[j] != currmap[j] )
    365                                 break;
    366                         if( j == m )
    367                         {
    368                             // re-use the map
    369                             tempCatOfs.push_back(ofs0);
    370                             continue;
    371                         }
    372                     }
    373                 }
    374                 else
    375                     ofshash[hashval] = i;
    376                 Vec2i ofs;
    377                 ofs[0] = (int)tempCatMap.size();
    378                 ofs[1] = ofs[0] + m;
    379                 tempCatOfs.push_back(ofs);
    380                 std::copy(labels.begin(), labels.end(), std::back_inserter(tempCatMap));
    381             }
    382             else
    383             {
    384                 tempCatOfs.push_back(Vec2i(0, 0));
    385                 /*Mat missing_i = layout == ROW_SAMPLE ? missing.col(i) : missing.row(i);
    386                 compare(missing_i, Scalar::all(0), non_missing, CMP_EQ);
    387                 missingSubst.at<float>(i) = (float)(mean(values_i, non_missing)[0]);*/
    388                 missingSubst.at<float>(i) = 0.f;
    389             }
    390         }
    391 
    392         if( !tempCatOfs.empty() )
    393         {
    394             Mat(tempCatOfs).copyTo(catOfs);
    395             Mat(tempCatMap).copyTo(catMap);
    396         }
    397 
    398         if( varType.at<uchar>(ninputvars) == VAR_CATEGORICAL )
    399         {
    400             preprocessCategorical(responses, &normCatResponses, labels, &counters, sortbuf);
    401             Mat(labels).copyTo(classLabels);
    402             Mat(counters).copyTo(classCounters);
    403         }
    404     }
    405 
    406     Mat convertMaskToIdx(const Mat& mask)
    407     {
    408         int i, j, nz = countNonZero(mask), n = mask.cols + mask.rows - 1;
    409         Mat idx(1, nz, CV_32S);
    410         for( i = j = 0; i < n; i++ )
    411             if( mask.at<uchar>(i) )
    412                 idx.at<int>(j++) = i;
    413         return idx;
    414     }
    415 
    // Ordering functor for sorting positions by the int value each one refers to:
    // position p refers to data[p*step]. Compares with '<' on those values.
    struct CmpByIdx
    {
        CmpByIdx(const int* _data, int _step) : data(_data), step(_step) {}

        bool operator ()(int i, int j) const
        {
            const int lhs = data[i*step];
            const int rhs = data[j*step];
            return lhs < rhs;
        }

        const int* data;
        int step;
    };
    423 
    424     void preprocessCategorical(const Mat& data, Mat* normdata, vector<int>& labels,
    425                                vector<int>* counters, vector<int>& sortbuf)
    426     {
    427         CV_Assert((data.cols == 1 || data.rows == 1) && (data.type() == CV_32S || data.type() == CV_32F));
    428         int* odata = 0;
    429         int ostep = 0;
    430 
    431         if(normdata)
    432         {
    433             normdata->create(data.size(), CV_32S);
    434             odata = normdata->ptr<int>();
    435             ostep = normdata->isContinuous() ? 1 : (int)normdata->step1();
    436         }
    437 
    438         int i, n = data.cols + data.rows - 1;
    439         sortbuf.resize(n*2);
    440         int* idx = &sortbuf[0];
    441         int* idata = (int*)data.ptr<int>();
    442         int istep = data.isContinuous() ? 1 : (int)data.step1();
    443 
    444         if( data.type() == CV_32F )
    445         {
    446             idata = idx + n;
    447             const float* fdata = data.ptr<float>();
    448             for( i = 0; i < n; i++ )
    449             {
    450                 if( fdata[i*istep] == MISSED_VAL )
    451                     idata[i] = -1;
    452                 else
    453                 {
    454                     idata[i] = cvRound(fdata[i*istep]);
    455                     CV_Assert( (float)idata[i] == fdata[i*istep] );
    456                 }
    457             }
    458             istep = 1;
    459         }
    460 
    461         for( i = 0; i < n; i++ )
    462             idx[i] = i;
    463 
    464         std::sort(idx, idx + n, CmpByIdx(idata, istep));
    465 
    466         int clscount = 1;
    467         for( i = 1; i < n; i++ )
    468             clscount += idata[idx[i]*istep] != idata[idx[i-1]*istep];
    469 
    470         int clslabel = -1;
    471         int prev = ~idata[idx[0]*istep];
    472         int previdx = 0;
    473 
    474         labels.resize(clscount);
    475         if(counters)
    476             counters->resize(clscount);
    477 
    478         for( i = 0; i < n; i++ )
    479         {
    480             int l = idata[idx[i]*istep];
    481             if( l != prev )
    482             {
    483                 clslabel++;
    484                 labels[clslabel] = l;
    485                 int k = i - previdx;
    486                 if( clslabel > 0 && counters )
    487                     counters->at(clslabel-1) = k;
    488                 prev = l;
    489                 previdx = i;
    490             }
    491             if(odata)
    492                 odata[idx[i]*ostep] = clslabel;
    493         }
    494         if(counters)
    495             counters->at(clslabel) = i - previdx;
    496     }
    497 
    498     bool loadCSV(const String& filename, int headerLines,
    499                  int responseStartIdx, int responseEndIdx,
    500                  const String& varTypeSpec, char delimiter, char missch)
    501     {
    502         const int M = 1000000;
    503         const char delimiters[3] = { ' ', delimiter, '\0' };
    504         int nvars = 0;
    505         bool varTypesSet = false;
    506 
    507         clear();
    508 
    509         file = fopen( filename.c_str(), "rt" );
    510 
    511         if( !file )
    512             return false;
    513 
    514         std::vector<char> _buf(M);
    515         std::vector<float> allresponses;
    516         std::vector<float> rowvals;
    517         std::vector<uchar> vtypes, rowtypes;
    518         bool haveMissed = false;
    519         char* buf = &_buf[0];
    520 
    521         int i, ridx0 = responseStartIdx, ridx1 = responseEndIdx;
    522         int ninputvars = 0, noutputvars = 0;
    523 
    524         Mat tempSamples, tempMissing, tempResponses;
    525         MapType tempNameMap;
    526         int catCounter = 1;
    527 
    528         // skip header lines
    529         int lineno = 0;
    530         for(;;lineno++)
    531         {
    532             if( !fgets(buf, M, file) )
    533                 break;
    534             if(lineno < headerLines )
    535                 continue;
    536             // trim trailing spaces
    537             int idx = (int)strlen(buf)-1;
    538             while( idx >= 0 && isspace(buf[idx]) )
    539                 buf[idx--] = '\0';
    540             // skip spaces in the beginning
    541             char* ptr = buf;
    542             while( *ptr != '\0' && isspace(*ptr) )
    543                 ptr++;
    544             // skip commented off lines
    545             if(*ptr == '#')
    546                 continue;
    547             rowvals.clear();
    548             rowtypes.clear();
    549 
    550             char* token = strtok(buf, delimiters);
    551             if (!token)
    552                 break;
    553 
    554             for(;;)
    555             {
    556                 float val=0.f; int tp = 0;
    557                 decodeElem( token, val, tp, missch, tempNameMap, catCounter );
    558                 if( tp == VAR_MISSED )
    559                     haveMissed = true;
    560                 rowvals.push_back(val);
    561                 rowtypes.push_back((uchar)tp);
    562                 token = strtok(NULL, delimiters);
    563                 if (!token)
    564                     break;
    565             }
    566 
    567             if( nvars == 0 )
    568             {
    569                 if( rowvals.empty() )
    570                     CV_Error(CV_StsBadArg, "invalid CSV format; no data found");
    571                 nvars = (int)rowvals.size();
    572                 if( !varTypeSpec.empty() && varTypeSpec.size() > 0 )
    573                 {
    574                     setVarTypes(varTypeSpec, nvars, vtypes);
    575                     varTypesSet = true;
    576                 }
    577                 else
    578                     vtypes = rowtypes;
    579 
    580                 ridx0 = ridx0 >= 0 ? ridx0 : ridx0 == -1 ? nvars - 1 : -1;
    581                 ridx1 = ridx1 >= 0 ? ridx1 : ridx0 >= 0 ? ridx0+1 : -1;
    582                 CV_Assert(ridx1 > ridx0);
    583                 noutputvars = ridx0 >= 0 ? ridx1 - ridx0 : 0;
    584                 ninputvars = nvars - noutputvars;
    585             }
    586             else
    587                 CV_Assert( nvars == (int)rowvals.size() );
    588 
    589             // check var types
    590             for( i = 0; i < nvars; i++ )
    591             {
    592                 CV_Assert( (!varTypesSet && vtypes[i] == rowtypes[i]) ||
    593                            (varTypesSet && (vtypes[i] == rowtypes[i] || rowtypes[i] == VAR_ORDERED)) );
    594             }
    595 
    596             if( ridx0 >= 0 )
    597             {
    598                 for( i = ridx1; i < nvars; i++ )
    599                     std::swap(rowvals[i], rowvals[i-noutputvars]);
    600                 for( i = ninputvars; i < nvars; i++ )
    601                     allresponses.push_back(rowvals[i]);
    602                 rowvals.pop_back();
    603             }
    604             Mat rmat(1, ninputvars, CV_32F, &rowvals[0]);
    605             tempSamples.push_back(rmat);
    606         }
    607 
    608         closeFile();
    609 
    610         int nsamples = tempSamples.rows;
    611         if( nsamples == 0 )
    612             return false;
    613 
    614         if( haveMissed )
    615             compare(tempSamples, MISSED_VAL, tempMissing, CMP_EQ);
    616 
    617         if( ridx0 >= 0 )
    618         {
    619             for( i = ridx1; i < nvars; i++ )
    620                 std::swap(vtypes[i], vtypes[i-noutputvars]);
    621             if( noutputvars > 1 )
    622             {
    623                 for( i = ninputvars; i < nvars; i++ )
    624                     if( vtypes[i] == VAR_CATEGORICAL )
    625                         CV_Error(CV_StsBadArg,
    626                                  "If responses are vector values, not scalars, they must be marked as ordered responses");
    627             }
    628         }
    629 
    630         if( !varTypesSet && noutputvars == 1 && vtypes[ninputvars] == VAR_ORDERED )
    631         {
    632             for( i = 0; i < nsamples; i++ )
    633                 if( allresponses[i] != cvRound(allresponses[i]) )
    634                     break;
    635             if( i == nsamples )
    636                 vtypes[ninputvars] = VAR_CATEGORICAL;
    637         }
    638 
    639         Mat(nsamples, noutputvars, CV_32F, &allresponses[0]).copyTo(tempResponses);
    640         setData(tempSamples, ROW_SAMPLE, tempResponses, noArray(), noArray(),
    641                 noArray(), Mat(vtypes).clone(), tempMissing);
    642         bool ok = !samples.empty();
    643         if(ok)
    644             std::swap(tempNameMap, nameMap);
    645         return ok;
    646     }
    647 
    648     void decodeElem( const char* token, float& elem, int& type,
    649                      char missch, MapType& namemap, int& counter ) const
    650     {
    651         char* stopstring = NULL;
    652         elem = (float)strtod( token, &stopstring );
    653         if( *stopstring == missch && strlen(stopstring) == 1 ) // missed value
    654         {
    655             elem = MISSED_VAL;
    656             type = VAR_MISSED;
    657         }
    658         else if( *stopstring != '\0' )
    659         {
    660             MapType::iterator it = namemap.find(token);
    661             if( it == namemap.end() )
    662             {
    663                 elem = (float)counter;
    664                 namemap[token] = counter++;
    665             }
    666             else
    667                 elem = (float)it->second;
    668             type = VAR_CATEGORICAL;
    669         }
    670         else
    671             type = VAR_ORDERED;
    672     }
    673 
    674     void setVarTypes( const String& s, int nvars, std::vector<uchar>& vtypes ) const
    675     {
    676         const char* errmsg = "type spec is not correct; it should have format \"cat\", \"ord\" or "
    677           "\"ord[n1,n2-n3,n4-n5,...]cat[m1-m2,m3,m4-m5,...]\", where n's and m's are 0-based variable indices";
    678         const char* str = s.c_str();
    679         int specCounter = 0;
    680 
    681         vtypes.resize(nvars);
    682 
    683         for( int k = 0; k < 2; k++ )
    684         {
    685             const char* ptr = strstr(str, k == 0 ? "ord" : "cat");
    686             int tp = k == 0 ? VAR_ORDERED : VAR_CATEGORICAL;
    687             if( ptr ) // parse ord/cat str
    688             {
    689                 char* stopstring = NULL;
    690 
    691                 if( ptr[3] == '\0' )
    692                 {
    693                     for( int i = 0; i < nvars; i++ )
    694                         vtypes[i] = (uchar)tp;
    695                     specCounter = nvars;
    696                     break;
    697                 }
    698 
    699                 if ( ptr[3] != '[')
    700                     CV_Error( CV_StsBadArg, errmsg );
    701 
    702                 ptr += 4; // pass "ord["
    703                 do
    704                 {
    705                     int b1 = (int)strtod( ptr, &stopstring );
    706                     if( *stopstring == 0 || (*stopstring != ',' && *stopstring != ']' && *stopstring != '-') )
    707                         CV_Error( CV_StsBadArg, errmsg );
    708                     ptr = stopstring + 1;
    709                     if( (stopstring[0] == ',') || (stopstring[0] == ']'))
    710                     {
    711                         CV_Assert( 0 <= b1 && b1 < nvars );
    712                         vtypes[b1] = (uchar)tp;
    713                         specCounter++;
    714                     }
    715                     else
    716                     {
    717                         if( stopstring[0] == '-')
    718                         {
    719                             int b2 = (int)strtod( ptr, &stopstring);
    720                             if ( (*stopstring == 0) || (*stopstring != ',' && *stopstring != ']') )
    721                                 CV_Error( CV_StsBadArg, errmsg );
    722                             ptr = stopstring + 1;
    723                             CV_Assert( 0 <= b1 && b1 <= b2 && b2 < nvars );
    724                             for (int i = b1; i <= b2; i++)
    725                                 vtypes[i] = (uchar)tp;
    726                             specCounter += b2 - b1 + 1;
    727                         }
    728                         else
    729                             CV_Error( CV_StsBadArg, errmsg );
    730 
    731                     }
    732                 }
    733                 while(*stopstring != ']');
    734 
    735                 if( stopstring[1] != '\0' && stopstring[1] != ',')
    736                     CV_Error( CV_StsBadArg, errmsg );
    737             }
    738         }
    739 
    740         if( specCounter != nvars )
    741             CV_Error( CV_StsBadArg, "type of some variables is not specified" );
    742     }
    743 
    744     void setTrainTestSplitRatio(double ratio, bool shuffle)
    745     {
    746         CV_Assert( 0. <= ratio && ratio <= 1. );
    747         setTrainTestSplit(cvRound(getNSamples()*ratio), shuffle);
    748     }
    749 
    750     void setTrainTestSplit(int count, bool shuffle)
    751     {
    752         int i, nsamples = getNSamples();
    753         CV_Assert( 0 <= count && count < nsamples );
    754 
    755         trainSampleIdx.release();
    756         testSampleIdx.release();
    757 
    758         if( count == 0 )
    759             trainSampleIdx = sampleIdx;
    760         else if( count == nsamples )
    761             testSampleIdx = sampleIdx;
    762         else
    763         {
    764             Mat mask(1, nsamples, CV_8U);
    765             uchar* mptr = mask.ptr();
    766             for( i = 0; i < nsamples; i++ )
    767                 mptr[i] = (uchar)(i < count);
    768             trainSampleIdx.create(1, count, CV_32S);
    769             testSampleIdx.create(1, nsamples - count, CV_32S);
    770             int j0 = 0, j1 = 0;
    771             const int* sptr = !sampleIdx.empty() ? sampleIdx.ptr<int>() : 0;
    772             int* trainptr = trainSampleIdx.ptr<int>();
    773             int* testptr = testSampleIdx.ptr<int>();
    774             for( i = 0; i < nsamples; i++ )
    775             {
    776                 int idx = sptr ? sptr[i] : i;
    777                 if( mptr[i] )
    778                     trainptr[j0++] = idx;
    779                 else
    780                     testptr[j1++] = idx;
    781             }
    782             if( shuffle )
    783                 shuffleTrainTest();
    784         }
    785     }
    786 
    787     void shuffleTrainTest()
    788     {
    789         if( !trainSampleIdx.empty() && !testSampleIdx.empty() )
    790         {
    791             int i, nsamples = getNSamples(), ntrain = getNTrainSamples(), ntest = getNTestSamples();
    792             int* trainIdx = trainSampleIdx.ptr<int>();
    793             int* testIdx = testSampleIdx.ptr<int>();
    794             RNG& rng = theRNG();
    795 
    796             for( i = 0; i < nsamples; i++)
    797             {
    798                 int a = rng.uniform(0, nsamples);
    799                 int b = rng.uniform(0, nsamples);
    800                 int* ptra = trainIdx;
    801                 int* ptrb = trainIdx;
    802                 if( a >= ntrain )
    803                 {
    804                     ptra = testIdx;
    805                     a -= ntrain;
    806                     CV_Assert( a < ntest );
    807                 }
    808                 if( b >= ntrain )
    809                 {
    810                     ptrb = testIdx;
    811                     b -= ntrain;
    812                     CV_Assert( b < ntest );
    813                 }
    814                 std::swap(ptra[a], ptrb[b]);
    815             }
    816         }
    817     }
    818 
    819     Mat getTrainSamples(int _layout,
    820                         bool compressSamples,
    821                         bool compressVars) const
    822     {
    823         if( samples.empty() )
    824             return samples;
    825 
    826         if( (!compressSamples || (trainSampleIdx.empty() && sampleIdx.empty())) &&
    827             (!compressVars || varIdx.empty()) &&
    828             layout == _layout )
    829             return samples;
    830 
    831         int drows = getNTrainSamples(), dcols = getNVars();
    832         Mat sidx = getTrainSampleIdx(), vidx = getVarIdx();
    833         const float* src0 = samples.ptr<float>();
    834         const int* sptr = !sidx.empty() ? sidx.ptr<int>() : 0;
    835         const int* vptr = !vidx.empty() ? vidx.ptr<int>() : 0;
    836         size_t sstep0 = samples.step/samples.elemSize();
    837         size_t sstep = layout == ROW_SAMPLE ? sstep0 : 1;
    838         size_t vstep = layout == ROW_SAMPLE ? 1 : sstep0;
    839 
    840         if( _layout == COL_SAMPLE )
    841         {
    842             std::swap(drows, dcols);
    843             std::swap(sptr, vptr);
    844             std::swap(sstep, vstep);
    845         }
    846 
    847         Mat dsamples(drows, dcols, CV_32F);
    848 
    849         for( int i = 0; i < drows; i++ )
    850         {
    851             const float* src = src0 + (sptr ? sptr[i] : i)*sstep;
    852             float* dst = dsamples.ptr<float>(i);
    853 
    854             for( int j = 0; j < dcols; j++ )
    855                 dst[j] = src[(vptr ? vptr[j] : j)*vstep];
    856         }
    857 
    858         return dsamples;
    859     }
    860 
    861     void getValues( int vi, InputArray _sidx, float* values ) const
    862     {
    863         Mat sidx = _sidx.getMat();
    864         int i, n = sidx.checkVector(1, CV_32S), nsamples = getNSamples();
    865         CV_Assert( 0 <= vi && vi < getNAllVars() );
    866         CV_Assert( n >= 0 );
    867         const int* s = n > 0 ? sidx.ptr<int>() : 0;
    868         if( n == 0 )
    869             n = nsamples;
    870 
    871         size_t step = samples.step/samples.elemSize();
    872         size_t sstep = layout == ROW_SAMPLE ? step : 1;
    873         size_t vstep = layout == ROW_SAMPLE ? 1 : step;
    874 
    875         const float* src = samples.ptr<float>() + vi*vstep;
    876         float subst = missingSubst.at<float>(vi);
    877         for( i = 0; i < n; i++ )
    878         {
    879             int j = i;
    880             if( s )
    881             {
    882                 j = s[i];
    883                 CV_Assert( 0 <= j && j < nsamples );
    884             }
    885             values[i] = src[j*sstep];
    886             if( values[i] == MISSED_VAL )
    887                 values[i] = subst;
    888         }
    889     }
    890 
    891     void getNormCatValues( int vi, InputArray _sidx, int* values ) const
    892     {
    893         float* fvalues = (float*)values;
    894         getValues(vi, _sidx, fvalues);
    895         int i, n = (int)_sidx.total();
    896         Vec2i ofs = catOfs.at<Vec2i>(vi);
    897         int m = ofs[1] - ofs[0];
    898 
    899         CV_Assert( m > 0 ); // if m==0, vi is an ordered variable
    900         const int* cmap = &catMap.at<int>(ofs[0]);
    901         bool fastMap = (m == cmap[m - 1] - cmap[0] + 1);
    902 
    903         if( fastMap )
    904         {
    905             for( i = 0; i < n; i++ )
    906             {
    907                 int val = cvRound(fvalues[i]);
    908                 int idx = val - cmap[0];
    909                 CV_Assert(cmap[idx] == val);
    910                 values[i] = idx;
    911             }
    912         }
    913         else
    914         {
    915             for( i = 0; i < n; i++ )
    916             {
    917                 int val = cvRound(fvalues[i]);
    918                 int a = 0, b = m, c = -1;
    919 
    920                 while( a < b )
    921                 {
    922                     c = (a + b) >> 1;
    923                     if( val < cmap[c] )
    924                         b = c;
    925                     else if( val > cmap[c] )
    926                         a = c+1;
    927                     else
    928                         break;
    929                 }
    930 
    931                 CV_DbgAssert( c >= 0 && val == cmap[c] );
    932                 values[i] = c;
    933             }
    934         }
    935     }
    936 
    937     void getSample(InputArray _vidx, int sidx, float* buf) const
    938     {
    939         CV_Assert(buf != 0 && 0 <= sidx && sidx < getNSamples());
    940         Mat vidx = _vidx.getMat();
    941         int i, n = vidx.checkVector(1, CV_32S), nvars = getNAllVars();
    942         CV_Assert( n >= 0 );
    943         const int* vptr = n > 0 ? vidx.ptr<int>() : 0;
    944         if( n == 0 )
    945             n = nvars;
    946 
    947         size_t step = samples.step/samples.elemSize();
    948         size_t sstep = layout == ROW_SAMPLE ? step : 1;
    949         size_t vstep = layout == ROW_SAMPLE ? 1 : step;
    950 
    951         const float* src = samples.ptr<float>() + sidx*sstep;
    952         for( i = 0; i < n; i++ )
    953         {
    954             int j = i;
    955             if( vptr )
    956             {
    957                 j = vptr[i];
    958                 CV_Assert( 0 <= j && j < nvars );
    959             }
    960             buf[i] = src[j*vstep];
    961         }
    962     }
    963 
    // Not referenced in this part of the file; presumably the handle used while
    // loading a CSV file — confirm against loadCSV().
    FILE* file;
    // Storage layout of `samples`: ROW_SAMPLE (sample stride = row step) or COL_SAMPLE.
    int layout;
    // `samples` is accessed as a float matrix; missingSubst(vi) is the value getValues()
    // substitutes for MISSED_VAL entries of variable vi.
    Mat samples, missing, varType, varIdx, responses, missingSubst;
    // Sample index lists (int / CV_32S); setTrainTestSplit() fills the train/test pair.
    Mat sampleIdx, trainSampleIdx, testSampleIdx;
    // catOfs(vi) is a Vec2i [begin, end) range into catMap holding the category labels
    // of variable vi; an empty range denotes an ordered variable.
    Mat sampleWeights, catMap, catOfs;
    Mat normCatResponses, classLabels, classCounters;
    // NOTE(review): MapType is declared elsewhere; presumably maps variable names — confirm.
    MapType nameMap;
    971 };
    972 
    973 Ptr<TrainData> TrainData::loadFromCSV(const String& filename,
    974                                       int headerLines,
    975                                       int responseStartIdx,
    976                                       int responseEndIdx,
    977                                       const String& varTypeSpec,
    978                                       char delimiter, char missch)
    979 {
    980     Ptr<TrainDataImpl> td = makePtr<TrainDataImpl>();
    981     if(!td->loadCSV(filename, headerLines, responseStartIdx, responseEndIdx, varTypeSpec, delimiter, missch))
    982         td.release();
    983     return td;
    984 }
    985 
    986 Ptr<TrainData> TrainData::create(InputArray samples, int layout, InputArray responses,
    987                                  InputArray varIdx, InputArray sampleIdx, InputArray sampleWeights,
    988                                  InputArray varType)
    989 {
    990     Ptr<TrainDataImpl> td = makePtr<TrainDataImpl>();
    991     td->setData(samples, layout, responses, varIdx, sampleIdx, sampleWeights, varType, noArray());
    992     return td;
    993 }
    994 
    995 }}
    996 
    997 /* End of file. */
    998