Home | History | Annotate | Download | only in grxmlcompile
      1 #ifndef __FST_IO_H__
      2 #define __FST_IO_H__
      3 
      4 // fst-io.h
      5 // This is a copy of the OPENFST SDK application sample files ...
      6 // except for the main functions ifdef'ed out
      7 // 2007, 2008 Nuance Communications
      8 //
      9 // print-main.h compile-main.h
     10 //
     11 // Licensed under the Apache License, Version 2.0 (the "License");
     12 // you may not use this file except in compliance with the License.
     13 // You may obtain a copy of the License at
     14 //
     15 //      http://www.apache.org/licenses/LICENSE-2.0
     16 //
     17 // Unless required by applicable law or agreed to in writing, software
     18 // distributed under the License is distributed on an "AS IS" BASIS,
     19 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     20 // See the License for the specific language governing permissions and
     21 // limitations under the License.
     22 //
     23 //
     24 // \file
// Classes to print an Fst in the OpenFst text format (FstPrinter) and to
// compile a VectorFst from that textual input (FstReader). The templated
// PrintMain/CompileMain helpers from the OpenFst fstprint/fstcompile sample
// programs are retained below but disabled with #if 0.
     28 
     29 #include <fstream>
     30 #include <sstream>
     31 
     32 #include "fst/lib/fst.h"
     33 #include "fst/lib/fstlib.h"
     34 #include "fst/lib/fst-decl.h"
     35 #include "fst/lib/vector-fst.h"
     36 #include "fst/lib/arcsort.h"
     37 #include "fst/lib/invert.h"
     38 
     39 namespace fst {
     40 
     41   template <class A> class FstPrinter {
     42   public:
     43     typedef A Arc;
     44     typedef typename A::StateId StateId;
     45     typedef typename A::Label Label;
     46     typedef typename A::Weight Weight;
     47 
     48     FstPrinter(const Fst<A> &fst,
     49 	       const SymbolTable *isyms,
     50 	       const SymbolTable *osyms,
     51 	       const SymbolTable *ssyms,
     52 	       bool accep)
     53       : fst_(fst), isyms_(isyms), osyms_(osyms), ssyms_(ssyms),
     54       accep_(accep && fst.Properties(kAcceptor, true)), ostrm_(0) {}
     55 
     56     // Print Fst to an output strm
     57     void Print(ostream *ostrm, const string &dest) {
     58       ostrm_ = ostrm;
     59       dest_ = dest;
     60       StateId start = fst_.Start();
     61       if (start == kNoStateId)
     62 	return;
     63       // initial state first
     64       PrintState(start);
     65       for (StateIterator< Fst<A> > siter(fst_);
     66 	   !siter.Done();
     67 	   siter.Next()) {
     68 	StateId s = siter.Value();
     69 	if (s != start)
     70 	  PrintState(s);
     71       }
     72     }
     73 
     74   private:
     75     // Maximum line length in text file.
     76     static const int kLineLen = 8096;
     77 
     78     void PrintId(int64 id, const SymbolTable *syms,
     79 		 const char *name) const {
     80       if (syms) {
     81 	string symbol = syms->Find(id);
     82 	if (symbol == "") {
     83 	  LOG(ERROR) << "FstPrinter: Integer " << id
     84 		     << " is not mapped to any textual symbol"
     85 		     << ", symbol table = " << syms->Name()
     86 		     << ", destination = " << dest_;
     87 	  exit(1);
     88 	}
     89 	*ostrm_ << symbol;
     90       } else {
     91 	*ostrm_ << id;
     92       }
     93     }
     94 
     95     void PrintStateId(StateId s) const {
     96       PrintId(s, ssyms_, "state ID");
     97     }
     98 
     99     void PrintILabel(Label l) const {
    100       PrintId(l, isyms_, "arc input label");
    101     }
    102 
    103     void PrintOLabel(Label l) const {
    104       PrintId(l, osyms_, "arc output label");
    105     }
    106 
    107     void PrintState(StateId s) const {
    108       bool output = false;
    109       for (ArcIterator< Fst<A> > aiter(fst_, s);
    110 	   !aiter.Done();
    111 	   aiter.Next()) {
    112 	Arc arc = aiter.Value();
    113 	PrintStateId(s);
    114 	*ostrm_ << "\t";
    115 	PrintStateId(arc.nextstate);
    116 	*ostrm_ << "\t";
    117 	PrintILabel(arc.ilabel);
    118 	if (!accep_) {
    119 	  *ostrm_ << "\t";
    120 	  PrintOLabel(arc.olabel);
    121 	}
    122 	if (arc.weight != Weight::One())
    123 	  *ostrm_ << "\t" << arc.weight;
    124 	*ostrm_ << "\n";
    125 	output = true;
    126       }
    127       Weight final = fst_.Final(s);
    128       if (final != Weight::Zero() || !output) {
    129 	PrintStateId(s);
    130 	if (final != Weight::One()) {
    131 	  *ostrm_ << "\t" << final;
    132 	}
    133 	*ostrm_ << "\n";
    134       }
    135     }
    136 
    137     const Fst<A> &fst_;
    138     const SymbolTable *isyms_;     // ilabel symbol table
    139     const SymbolTable *osyms_;     // olabel symbol table
    140     const SymbolTable *ssyms_;     // slabel symbol table
    141     bool accep_;                   // print as acceptor when possible
    142     ostream *ostrm_;                // binary FST destination
    143     string dest_;                  // binary FST destination name
    144     DISALLOW_EVIL_CONSTRUCTORS(FstPrinter);
    145   };
    146 
    147 #if 0
    148   // Main function for fstprint templated on the arc type.
    149   template <class Arc>
    150     int PrintMain(int argc, char **argv, istream &istrm,
    151 		  const FstReadOptions &opts) {
    152     Fst<Arc> *fst = Fst<Arc>::Read(istrm, opts);
    153     if (!fst) return 1;
    154 
    155     string dest = "standard output";
    156     ostream *ostrm = &std::cout;
    157     if (argc == 3) {
    158       dest = argv[2];
    159       ostrm = new ofstream(argv[2]);
    160       if (!*ostrm) {
    161 	LOG(ERROR) << argv[0] << ": Open failed, file = " << argv[2];
    162 	return 0;
    163       }
    164     }
    165     ostrm->precision(9);
    166 
    167     const SymbolTable *isyms = 0, *osyms = 0, *ssyms = 0;
    168 
    169     if (!FLAGS_isymbols.empty() && !FLAGS_numeric) {
    170       isyms = SymbolTable::ReadText(FLAGS_isymbols);
    171       if (!isyms) exit(1);
    172     }
    173 
    174     if (!FLAGS_osymbols.empty() && !FLAGS_numeric) {
    175       osyms = SymbolTable::ReadText(FLAGS_osymbols);
    176       if (!osyms) exit(1);
    177     }
    178 
    179     if (!FLAGS_ssymbols.empty() && !FLAGS_numeric) {
    180       ssyms = SymbolTable::ReadText(FLAGS_ssymbols);
    181       if (!ssyms) exit(1);
    182     }
    183 
    184     if (!isyms && !FLAGS_numeric)
    185       isyms = fst->InputSymbols();
    186     if (!osyms && !FLAGS_numeric)
    187       osyms = fst->OutputSymbols();
    188 
    189     FstPrinter<Arc> fstprinter(*fst, isyms, osyms, ssyms, FLAGS_acceptor);
    190     fstprinter.Print(ostrm, dest);
    191 
    192     if (isyms && !FLAGS_save_isymbols.empty())
    193       isyms->WriteText(FLAGS_save_isymbols);
    194 
    195     if (osyms && !FLAGS_save_osymbols.empty())
    196       osyms->WriteText(FLAGS_save_osymbols);
    197 
    198     if (ostrm != &std::cout)
    199       delete ostrm;
    200     return 0;
    201   }
    202 #endif
    203 
    204 
    205   template <class A> class FstReader {
    206   public:
    207     typedef A Arc;
    208     typedef typename A::StateId StateId;
    209     typedef typename A::Label Label;
    210     typedef typename A::Weight Weight;
    211 
    212     FstReader(istream &istrm, const string &source,
    213 	      const SymbolTable *isyms, const SymbolTable *osyms,
    214 	      const SymbolTable *ssyms, bool accep, bool ikeep,
    215 	      bool okeep, bool nkeep)
    216       : nline_(0), source_(source),
    217       isyms_(isyms), osyms_(osyms), ssyms_(ssyms),
    218       nstates_(0), keep_state_numbering_(nkeep) {
    219       char line[kLineLen];
    220       while (istrm.getline(line, kLineLen)) {
    221 	++nline_;
    222 	vector<char *> col;
    223 	SplitToVector(line, "\n\t ", &col, true);
    224 	if (col.size() == 0 || col[0][0] == '\0')  // empty line
    225 	  continue;
    226 	if (col.size() > 5 ||
    227 	    col.size() > 4 && accep ||
    228 	    col.size() == 3 && !accep) {
    229 	  LOG(ERROR) << "FstReader: Bad number of columns, source = " << source_
    230 		     << ", line = " << nline_;
    231 	  exit(1);
    232 	}
    233 	StateId s = StrToStateId(col[0]);
    234 	while (s >= fst_.NumStates())
    235 	  fst_.AddState();
    236 	if (nline_ == 1)
    237 	  fst_.SetStart(s);
    238 
    239 	Arc arc;
    240 	StateId d = s;
    241 	switch (col.size()) {
    242 	case 1:
    243 	  fst_.SetFinal(s, Weight::One());
    244 	  break;
    245 	case 2:
    246 	  fst_.SetFinal(s, StrToWeight(col[1], true));
    247 	  break;
    248 	case 3:
    249 	  arc.nextstate = d = StrToStateId(col[1]);
    250 	  arc.ilabel = StrToILabel(col[2]);
    251 	  arc.olabel = arc.ilabel;
    252 	  arc.weight = Weight::One();
    253 	  fst_.AddArc(s, arc);
    254 	  break;
    255 	case 4:
    256 	  arc.nextstate = d = StrToStateId(col[1]);
    257 	  arc.ilabel = StrToILabel(col[2]);
    258 	  if (accep) {
    259 	    arc.olabel = arc.ilabel;
    260 	    arc.weight = StrToWeight(col[3], false);
    261 	  } else {
    262 	    arc.olabel = StrToOLabel(col[3]);
    263 	    arc.weight = Weight::One();
    264 	  }
    265 	  fst_.AddArc(s, arc);
    266 	  break;
    267 	case 5:
    268 	  arc.nextstate = d = StrToStateId(col[1]);
    269 	  arc.ilabel = StrToILabel(col[2]);
    270 	  arc.olabel = StrToOLabel(col[3]);
    271 	  arc.weight = StrToWeight(col[4], false);
    272 	  fst_.AddArc(s, arc);
    273 	}
    274 	while (d >= fst_.NumStates())
    275 	  fst_.AddState();
    276       }
    277       if (ikeep)
    278 	fst_.SetInputSymbols(isyms);
    279       if (okeep)
    280 	fst_.SetOutputSymbols(osyms);
    281     }
    282 
    283     const VectorFst<A> &Fst() const { return fst_; }
    284 
    285   private:
    286     // Maximum line length in text file.
    287     static const int kLineLen = 8096;
    288 
    289     int64 StrToId(const char *s, const SymbolTable *syms,
    290 		  const char *name) const {
    291       int64 n;
    292 
    293       if (syms) {
    294 	n = syms->Find(s);
    295 	if (n < 0) {
    296 	  LOG(ERROR) << "FstReader: Symbol \"" << s
    297 		     << "\" is not mapped to any integer " << name
    298 		     << ", symbol table = " << syms->Name()
    299 		     << ", source = " << source_ << ", line = " << nline_;
    300 	  exit(1);
    301 	}
    302       } else {
    303 	char *p;
    304 	n = strtoll(s, &p, 10);
    305 	if (p < s + strlen(s) || n < 0) {
    306 	  LOG(ERROR) << "FstReader: Bad " << name << " integer = \"" << s
    307 		     << "\", source = " << source_ << ", line = " << nline_;
    308 	  exit(1);
    309 	}
    310       }
    311       return n;
    312     }
    313 
    314     StateId StrToStateId(const char *s) {
    315       StateId n = StrToId(s, ssyms_, "state ID");
    316 
    317       if (keep_state_numbering_)
    318 	return n;
    319 
    320       // remap state IDs to make dense set
    321       typename hash_map<StateId, StateId>::const_iterator it = states_.find(n);
    322       if (it == states_.end()) {
    323 	states_[n] = nstates_;
    324 	return nstates_++;
    325       } else {
    326 	return it->second;
    327       }
    328     }
    329 
    330     StateId StrToILabel(const char *s) const {
    331       return StrToId(s, isyms_, "arc ilabel");
    332     }
    333 
    334     StateId StrToOLabel(const char *s) const {
    335       return StrToId(s, osyms_, "arc olabel");
    336     }
    337 
    338     Weight StrToWeight(const char *s, bool allow_zero) const {
    339       Weight w;
    340       istringstream strm(s);
    341       strm >> w;
    342       if (strm.fail() || !allow_zero && w == Weight::Zero()) {
    343 	LOG(ERROR) << "FstReader: Bad weight = \"" << s
    344 		   << "\", source = " << source_ << ", line = " << nline_;
    345 	exit(1);
    346       }
    347       return w;
    348     }
    349 
    350     VectorFst<A> fst_;
    351     size_t nline_;
    352     string source_;                      // text FST source name
    353     const SymbolTable *isyms_;           // ilabel symbol table
    354     const SymbolTable *osyms_;           // olabel symbol table
    355     const SymbolTable *ssyms_;           // slabel symbol table
    356     hash_map<StateId, StateId> states_;  // state ID map
    357     StateId nstates_;                    // number of seen states
    358     bool keep_state_numbering_;
    359     DISALLOW_EVIL_CONSTRUCTORS(FstReader);
    360   };
    361 
    362 #if 0
    363   // Main function for fstcompile templated on the arc type.  Last two
    364   // arguments unneeded since fstcompile passes the arc type as a flag
    365   // unlike the other mains, which infer the arc type from an input Fst.
    366   template <class Arc>
    367     int CompileMain(int argc, char **argv, istream& /* strm */,
    368 		    const FstReadOptions & /* opts */) {
    369     char *ifilename = "standard input";
    370     istream *istrm = &std::cin;
    371     if (argc > 1 && strcmp(argv[1], "-") != 0) {
    372       ifilename = argv[1];
    373       istrm = new ifstream(ifilename);
    374       if (!*istrm) {
    375 	LOG(ERROR) << argv[0] << ": Open failed, file = " << ifilename;
    376 	return 1;
    377       }
    378     }
    379     const SymbolTable *isyms = 0, *osyms = 0, *ssyms = 0;
    380 
    381     if (!FLAGS_isymbols.empty()) {
    382       isyms = SymbolTable::ReadText(FLAGS_isymbols);
    383       if (!isyms) exit(1);
    384     }
    385 
    386     if (!FLAGS_osymbols.empty()) {
    387       osyms = SymbolTable::ReadText(FLAGS_osymbols);
    388       if (!osyms) exit(1);
    389     }
    390 
    391     if (!FLAGS_ssymbols.empty()) {
    392       ssyms = SymbolTable::ReadText(FLAGS_ssymbols);
    393       if (!ssyms) exit(1);
    394     }
    395 
    396     FstReader<Arc> fstreader(*istrm, ifilename, isyms, osyms, ssyms,
    397 			     FLAGS_acceptor, FLAGS_keep_isymbols,
    398 			     FLAGS_keep_osymbols, FLAGS_keep_state_numbering);
    399 
    400     const Fst<Arc> *fst = &fstreader.Fst();
    401     if (FLAGS_fst_type != "vector") {
    402       fst = Convert<Arc>(*fst, FLAGS_fst_type);
    403       if (!fst) return 1;
    404     }
    405     fst->Write(argc > 2 ? argv[2] : "");
    406     if (istrm != &std::cin)
    407       delete istrm;
    408     return 0;
    409   }
    410 #endif
    411 
    412 }  // namespace fst
    413 
    414 #endif /* __FST_IO_H__ */
    415 
    416