Home | History | Annotate | Download | only in script
      1 // compile.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // \file
     19 // Class to compile a binary Fst from textual input.
     20 
     21 #ifndef FST_SCRIPT_COMPILE_IMPL_H_
     22 #define FST_SCRIPT_COMPILE_IMPL_H_
     23 
     24 #include <tr1/unordered_map>
     25 using std::tr1::unordered_map;
     26 using std::tr1::unordered_multimap;
     27 #include <sstream>
     28 #include <string>
     29 #include <vector>
     30 using std::vector;
     31 
     32 #include <iostream>
     33 #include <fstream>
     34 #include <sstream>
     35 #include <fst/fst.h>
     36 #include <fst/util.h>
     37 #include <fst/vector-fst.h>
     38 
     39 DECLARE_string(fst_field_separator);
     40 
     41 namespace fst {
     42 
     43 // Compile a binary Fst from textual input, helper class for fstcompile.cc
     44 // WARNING: Stand-alone use of this class not recommended, most code should
     45 // read/write using the binary format which is much more efficient.
     46 template <class A> class FstCompiler {
     47  public:
     48   typedef A Arc;
     49   typedef typename A::StateId StateId;
     50   typedef typename A::Label Label;
     51   typedef typename A::Weight Weight;
     52 
     53   // WARNING: use of 'allow_negative_labels = true' not recommended; may
     54   // cause conflicts
     55   FstCompiler(istream &istrm, const string &source,
     56             const SymbolTable *isyms, const SymbolTable *osyms,
     57             const SymbolTable *ssyms, bool accep, bool ikeep,
     58               bool okeep, bool nkeep, bool allow_negative_labels = false)
     59       : nline_(0), source_(source),
     60         isyms_(isyms), osyms_(osyms), ssyms_(ssyms),
     61         nstates_(0), keep_state_numbering_(nkeep),
     62         allow_negative_labels_(allow_negative_labels) {
     63     char line[kLineLen];
     64     while (istrm.getline(line, kLineLen)) {
     65       ++nline_;
     66       vector<char *> col;
     67       string separator = FLAGS_fst_field_separator + "\n";
     68       SplitToVector(line, separator.c_str(), &col, true);
     69       if (col.size() == 0 || col[0][0] == '\0')  // empty line
     70         continue;
     71       if (col.size() > 5 ||
     72           (col.size() > 4 && accep) ||
     73           (col.size() == 3 && !accep)) {
     74         FSTERROR() << "FstCompiler: Bad number of columns, source = "
     75                    << source_
     76                    << ", line = " << nline_;
     77         fst_.SetProperties(kError, kError);
     78         return;
     79       }
     80       StateId s = StrToStateId(col[0]);
     81       while (s >= fst_.NumStates())
     82         fst_.AddState();
     83       if (nline_ == 1)
     84         fst_.SetStart(s);
     85 
     86       Arc arc;
     87       StateId d = s;
     88       switch (col.size()) {
     89       case 1:
     90         fst_.SetFinal(s, Weight::One());
     91         break;
     92       case 2:
     93         fst_.SetFinal(s, StrToWeight(col[1], true));
     94         break;
     95       case 3:
     96         arc.nextstate = d = StrToStateId(col[1]);
     97         arc.ilabel = StrToILabel(col[2]);
     98         arc.olabel = arc.ilabel;
     99         arc.weight = Weight::One();
    100         fst_.AddArc(s, arc);
    101         break;
    102       case 4:
    103         arc.nextstate = d = StrToStateId(col[1]);
    104         arc.ilabel = StrToILabel(col[2]);
    105         if (accep) {
    106           arc.olabel = arc.ilabel;
    107           arc.weight = StrToWeight(col[3], false);
    108         } else {
    109           arc.olabel = StrToOLabel(col[3]);
    110           arc.weight = Weight::One();
    111         }
    112         fst_.AddArc(s, arc);
    113         break;
    114       case 5:
    115         arc.nextstate = d = StrToStateId(col[1]);
    116         arc.ilabel = StrToILabel(col[2]);
    117         arc.olabel = StrToOLabel(col[3]);
    118         arc.weight = StrToWeight(col[4], false);
    119         fst_.AddArc(s, arc);
    120       }
    121       while (d >= fst_.NumStates())
    122         fst_.AddState();
    123     }
    124     if (ikeep)
    125       fst_.SetInputSymbols(isyms);
    126     if (okeep)
    127       fst_.SetOutputSymbols(osyms);
    128   }
    129 
    130   const VectorFst<A> &Fst() const {
    131     return fst_;
    132   }
    133 
    134  private:
    135   // Maximum line length in text file.
    136   static const int kLineLen = 8096;
    137 
    138   int64 StrToId(const char *s, const SymbolTable *syms,
    139                 const char *name, bool allow_negative = false) const {
    140     int64 n = 0;
    141 
    142     if (syms) {
    143       n = syms->Find(s);
    144       if (n == -1 || (!allow_negative && n < 0)) {
    145         FSTERROR() << "FstCompiler: Symbol \"" << s
    146                    << "\" is not mapped to any integer " << name
    147                    << ", symbol table = " << syms->Name()
    148                    << ", source = " << source_ << ", line = " << nline_;
    149         fst_.SetProperties(kError, kError);
    150       }
    151     } else {
    152       char *p;
    153       n = strtoll(s, &p, 10);
    154       if (p < s + strlen(s) || (!allow_negative && n < 0)) {
    155         FSTERROR() << "FstCompiler: Bad " << name << " integer = \"" << s
    156                    << "\", source = " << source_ << ", line = " << nline_;
    157         fst_.SetProperties(kError, kError);
    158       }
    159     }
    160     return n;
    161   }
    162 
    163   StateId StrToStateId(const char *s) {
    164     StateId n = StrToId(s, ssyms_, "state ID");
    165 
    166     if (keep_state_numbering_)
    167       return n;
    168 
    169     // remap state IDs to make dense set
    170     typename unordered_map<StateId, StateId>::const_iterator it = states_.find(n);
    171     if (it == states_.end()) {
    172       states_[n] = nstates_;
    173       return nstates_++;
    174     } else {
    175       return it->second;
    176     }
    177   }
    178 
    179   StateId StrToILabel(const char *s) const {
    180     return StrToId(s, isyms_, "arc ilabel", allow_negative_labels_);
    181   }
    182 
    183   StateId StrToOLabel(const char *s) const {
    184     return StrToId(s, osyms_, "arc olabel", allow_negative_labels_);
    185   }
    186 
    187   Weight StrToWeight(const char *s, bool allow_zero) const {
    188     Weight w;
    189     istringstream strm(s);
    190     strm >> w;
    191     if (!strm || (!allow_zero && w == Weight::Zero())) {
    192       FSTERROR() << "FstCompiler: Bad weight = \"" << s
    193                  << "\", source = " << source_ << ", line = " << nline_;
    194       fst_.SetProperties(kError, kError);
    195       w = Weight::NoWeight();
    196     }
    197     return w;
    198   }
    199 
    200   mutable VectorFst<A> fst_;
    201   size_t nline_;
    202   string source_;                      // text FST source name
    203   const SymbolTable *isyms_;           // ilabel symbol table
    204   const SymbolTable *osyms_;           // olabel symbol table
    205   const SymbolTable *ssyms_;           // slabel symbol table
    206   unordered_map<StateId, StateId> states_;  // state ID map
    207   StateId nstates_;                    // number of seen states
    208   bool keep_state_numbering_;
    209   bool allow_negative_labels_;         // not recommended; may cause conflicts
    210 
    211   DISALLOW_COPY_AND_ASSIGN(FstCompiler);
    212 };
    213 
    214 }  // namespace fst
    215 
    216 #endif  // FST_SCRIPT_COMPILE_IMPL_H_
    217