Home | History | Annotate | Download | only in make_cfst
      1 /*---------------------------------------------------------------------------*
      2  *  make_cfst.cpp                                                            *
      3  *                                                                           *
      4  *  Copyright 2007, 2008 Nuance Communications, Inc.                               *
      5  *                                                                           *
      6  *  Licensed under the Apache License, Version 2.0 (the 'License');          *
      7  *  you may not use this file except in compliance with the License.         *
      8  *                                                                           *
      9  *  You may obtain a copy of the License at                                  *
     10  *      http://www.apache.org/licenses/LICENSE-2.0                           *
     11  *                                                                           *
     12  *  Unless required by applicable law or agreed to in writing, software      *
     13  *  distributed under the License is distributed on an 'AS IS' BASIS,        *
     14  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
     15  *  See the License for the specific language governing permissions and      *
     16  *  limitations under the License.                                           *
     17  *                                                                           *
     18  *---------------------------------------------------------------------------*/
     19 
      #include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "ptypes.h"
#include "srec_arb.h"
#include "simapi.h"

#include "fst/lib/fstlib.h"
#include "fst/lib/fst-decl.h"
#include "fst/lib/vector-fst.h"
#include "fst/lib/arcsort.h"
#include "fst/lib/invert.h"
#include "fst/lib/rmepsilon.h"
     30 
      31 #define MAX_LINE_LENGTH     256  /* not referenced by the code below */
      32 #define MAX_PRONS_LENGTH 1024  /* not referenced by the code below */
      33 #define EPSILON_LABEL 0  /* label value used for every epsilon (empty) arc built below */
      34 #define MAX_MODELS 1024  /* capacity of the per-model Minifst table; model ids must be below this */
      35 #define MAXPHID 8888  /* upper cap applied to phoneme-id loop indices in main() */
      36 #define MAX_PHONEMES 128  /* length of the per-phoneme arrays inside each Minifst record */
     37 
     38 using namespace fst;
     39 
     40 int usage(const char* prog)
     41 {
     42   printf("usage: %s  -phones models/phones.map -models models/models128x.map -cfst models/generic.C -swiarb models/generic.swiarb\n", prog);
     43   return 1;
     44 }
     45 
     46 typedef struct Minifst_t
     47 {
     48   char lcontexts[MAX_PHONEMES];
     49   char rcontexts[MAX_PHONEMES];
     50   int modelId;
     51   int stateSt;
     52   int stateEn;
     53   phonemeID phonemeId;
     54   unsigned char phonemeCode;
     55   int lcontext_state[MAX_PHONEMES];
     56   int rcontext_state[MAX_PHONEMES];
     57 } Minifst;
     58 
     59 int main(int argc, char **argv)
     60 {
     61   char* cfstFilename = 0;
     62   char* swiarbFilename = 0;
     63   char* phonesMap;
     64   char* modelsMap;
     65   int i;
     66   phonemeID lphonId, cphonId, rphonId;
     67   unsigned char cphon;
     68   modelID modelId, max_modelId = 0;
     69   int stateSt, stateEn;
     70   int stateN, stateNp1;
     71   int rc;
     72   Minifst minifst[MAX_MODELS];
     73   int do_show_text = 1;
     74   int do_until_step = 99;
     75 
     76   /* initial memory */
     77   rc = PMemInit();
     78   ASSERT( rc == ESR_SUCCESS);
     79 
     80   // A vector FST is a general mutable FST
     81   fst::StdVectorFst myCfst;
     82   // fst::Fst<fst::StdArc> myCfst;
     83 
     84   if(argc <= 1)
     85     return usage(argv[0]);
     86 
     87   for(i=1; i<argc; i++)
     88   {
     89     if(0) ;
     90     else if(!strcmp(argv[i],"-phones"))
     91       phonesMap = argv[++i];
     92     else if(!strcmp(argv[i],"-models"))
     93       modelsMap = argv[++i];
     94     else if(!strcmp(argv[i],"-cfst"))
     95       cfstFilename = argv[++i];
     96     else if(!strcmp(argv[i],"-step"))
     97       do_until_step = atoi(argv[++i]);
     98     else if(!strcmp(argv[i],"-swiarb"))
     99       swiarbFilename = argv[++i];
    100     else {
    101       return usage(argv[0]);
    102     }
    103   }
    104 
    105   printf("loading %s ...\n", swiarbFilename);
    106   CA_Arbdata* ca_arbdata = CA_LoadArbdata(swiarbFilename);
    107   srec_arbdata *allotree = (srec_arbdata*)ca_arbdata;
    108 
    109 
    110   /*-------------------------------------------------------------------------
    111    *
    112    *       /---<---<---<---<---<---<---\
    113    *      /                             \
    114    *     /       -wb--         -wb-      \
    115    *    /        \   /         \  /       \
    116    *   0 ---#--->  n ----M---> n+1 ---#---> 1
    117    *
    118    *
    119    *
    120    *
    121    *-------------------------------------------------------------------------
    122    */
    123 
    124   // Adds state 0 to the initially empty FST and make it the start state.
    125   stateSt = myCfst.AddState();   // 1st state will be state 0 (returned by AddState)
    126   stateEn = myCfst.AddState();
    127   myCfst.SetStart(stateSt);  // arg is state ID
    128   myCfst.SetFinal(stateEn, 0.0);  // 1st arg is state ID, 2nd arg weight
    129   myCfst.AddArc(stateEn, fst::StdArc(EPSILON_LABEL, EPSILON_LABEL, 0.0, stateSt));
    130 
    131   phonemeID silencePhonId = 0;
    132   modelID silenceModelId = 0;
    133   silenceModelId = (modelID)get_modelid_for_pic(allotree, silencePhonId, silencePhonId, silencePhonId);
    134   // silenceModelId += MODEL_LABEL_OFFSET; #define MODEL_LABEL_OFFSET 128
    135 
    136   for(modelId=0; modelId<MAX_MODELS; modelId++) {
    137     minifst[modelId].modelId = MAXmodelID;
    138     minifst[modelId].stateSt = minifst[modelId].stateEn = 0;
    139     minifst[modelId].phonemeId = MAXphonemeID;
    140     minifst[modelId].phonemeCode = 0;
    141     for(i=0;i<MAX_PHONEMES;i++) {
    142       minifst[modelId].lcontexts[i] = minifst[modelId].rcontexts[i] = 0;
    143       minifst[modelId].lcontext_state[i] = minifst[modelId].rcontext_state[i] = 0;
    144     }
    145   }
    146 
    147   for(cphonId=0; cphonId<allotree->num_phonemes && cphonId<MAXPHID; cphonId++) {
    148     cphon = allotree->pdata[cphonId].code;
    149     printf("processing phoneme %d of %d %d %c\n", cphonId, allotree->num_phonemes, cphon, cphon);
    150 
    151     for(lphonId=0; lphonId<allotree->num_phonemes && lphonId<MAXPHID; lphonId++) {
    152       unsigned char lphon = allotree->pdata[lphonId].code;
    153       for(rphonId=0; rphonId<allotree->num_phonemes && rphonId<MAXPHID; rphonId++) {
    154 	unsigned char rphon = allotree->pdata[rphonId].code;
    155 	if( 1|| cphon=='a') { //22222
    156 	  modelId = (modelID)get_modelid_for_pic(allotree, lphonId, cphonId, rphonId);
    157 	} else {
    158 	  modelId = (modelID)get_modelid_for_pic(allotree, 0, cphonId, 0);
    159 	}
    160 	if(modelId == MAXmodelID) {
    161 	  printf("error while get_modelid_for_pic( %p, %d, %d, %d)\n",
    162 		 allotree, lphonId, cphonId, rphonId);
    163 	  continue;
    164 	} else
    165 	  if(do_show_text) printf("%c %c %c hmm%03d_%c %d %d %d\n", lphon, cphon, rphon, modelId, cphon, lphonId, cphonId, rphonId);
    166 	ASSERT(modelId < MAX_MODELS);
    167 	minifst[ modelId].phonemeId = cphonId;
    168 	minifst[ modelId].phonemeCode = cphon;
    169 	minifst[ modelId].modelId = modelId;
    170 	minifst[ modelId].lcontexts[lphonId] = 1;
    171 	minifst[ modelId].rcontexts[rphonId] = 1;
    172 	if(modelId>max_modelId) max_modelId = modelId;
    173       }
    174     }
    175   }
    176 
    177   printf("adding model arcs .. max_modelId %d\n",max_modelId);
    178   for(modelId=0; modelId<=max_modelId; modelId++) {
    179     if( minifst[modelId].modelId == MAXmodelID) continue;
    180     cphon = minifst[modelId].phonemeCode;
    181     minifst[modelId].stateSt = (stateN = myCfst.AddState());
    182     minifst[modelId].stateEn = (stateNp1 = myCfst.AddState()); /* n plus 1 */
    183     myCfst.AddArc( stateN, fst::StdArc(cphon,modelId,0.0,stateNp1));
    184     myCfst.AddArc( stateNp1, fst::StdArc(WORD_BOUNDARY,WORD_BOUNDARY,0.0,stateNp1));
    185 
    186     if(do_show_text) printf("%d\t\%d\t%c\t\%d\n", stateN,stateNp1,cphon,modelId);
    187 #if 1
    188     for( lphonId=0; lphonId<allotree->num_phonemes; lphonId++) {
    189       minifst[modelId].lcontext_state[lphonId] = myCfst.AddState();
    190       myCfst.AddArc( minifst[modelId].lcontext_state[lphonId],
    191 		  fst::StdArc(EPSILON_LABEL,EPSILON_LABEL,0.0,
    192 			      minifst[modelId].stateSt));
    193 
    194     }
    195     for( rphonId=0; rphonId<allotree->num_phonemes; rphonId++) {
    196       minifst[modelId].rcontext_state[rphonId] = myCfst.AddState();
    197       myCfst.AddArc( minifst[modelId].stateEn,
    198 		  fst::StdArc(EPSILON_LABEL,EPSILON_LABEL,0.0,
    199 			      minifst[modelId].rcontext_state[rphonId]));
    200     }
    201 #endif
    202   }
    203 #if 1
    204   printf("adding cross-connections\n");
    205   for( modelId=0; modelId<=max_modelId; modelId++) {
    206     printf("processing model %d\n", modelId);
    207     if( minifst[modelId].modelId == MAXmodelID) continue;
    208     cphonId = minifst[modelId].phonemeId;
    209     for( modelID mId=0; mId<=max_modelId; mId++) {
    210       if( minifst[mId].modelId != MAXmodelID &&
    211 	  // minifst[mId].phonemeId == rphonId &&
    212 	  minifst[modelId].rcontexts[ minifst[mId].phonemeId] == 1 &&
    213 	  minifst[mId].lcontexts[ cphonId] == 1) {
    214 	myCfst.AddArc( minifst[modelId].stateEn,
    215 		    fst::StdArc(EPSILON_LABEL,EPSILON_LABEL,0.0,
    216 				minifst[mId].stateSt));
    217       }
    218     }
    219   }
    220   /* start node connections */
    221   myCfst.AddArc( stateSt,
    222 	      fst::StdArc(EPSILON_LABEL, EPSILON_LABEL, 0.0,
    223 			  minifst[silenceModelId].stateSt));
    224   myCfst.AddArc(  minifst[silenceModelId].stateEn,
    225 	      fst::StdArc(EPSILON_LABEL, EPSILON_LABEL, 0.0, stateEn));
    226 #endif
    227 
    228   fst::StdVectorFst fst2;
    229   fst::StdVectorFst* ofst = &myCfst;
    230   if(do_until_step>0) {
    231     printf("invert\n");
    232     fst::Invert(&myCfst);
    233     bool FLAGS_connect = true;
    234     if(do_until_step>1) {
    235       printf("rmepsilon\n");
    236       fst::RmEpsilon( &myCfst, FLAGS_connect);
    237       if(do_until_step>2) {
    238 	printf("determinize\n");
    239 	fst::Determinize(myCfst, &fst2);
    240 	ofst = &fst2;
    241 	if(do_until_step>3) {
    242 	  printf("arcsort olabels\n");
    243 	  fst::ArcSort(&fst2, fst::StdOLabelCompare());
    244 	}
    245       }
    246     }
    247   }
    248 
    249 #if 0
    250   for(fst::SymbolTableIterator syms_iter( *syms); !syms_iter.Done(); syms_iter.Next() ) {
    251     int value = (int)syms_iter.Value();
    252     const char* key = syms_iter.Symbol();
    253   }
    254 #endif
    255 
    256   printf("writing output file %s\n", cfstFilename);
    257 
    258   // We can save this FST to a file with:
    259   /* fail compilation if char and LCHAR aren't the same! */
    260 
    261   { char zzz[ 1 - (sizeof(LCHAR)!=sizeof(char))]; zzz[0] = 0; }
    262   ofst->Write((const char*)cfstFilename);
    263 
    264   CA_FreeArbdata( ca_arbdata);
    265 
    266   PMemShutdown();
    267 
    268   //  CLEANUP:
    269   return (int)rc;
    270 }
    271 
    272 
    273