Home | History | Annotate | Download | only in script
      1 // compile.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // \file
// Class to compile a binary Fst from textual input.
     20 
     21 #ifndef FST_SCRIPT_COMPILE_IMPL_H_
     22 #define FST_SCRIPT_COMPILE_IMPL_H_
     23 
     24 #include <unordered_map>
     25 using std::tr1::unordered_map;
     26 using std::tr1::unordered_multimap;
     27 #include <sstream>
     28 #include <string>
     29 #include <vector>
     30 using std::vector;
     31 
     32 #include <iostream>
     33 #include <fstream>
     34 #include <fst/fst.h>
     35 #include <fst/util.h>
     36 #include <fst/vector-fst.h>
     37 
     38 DECLARE_string(fst_field_separator);
     39 
     40 namespace fst {
     41 
     42 // Compile a binary Fst from textual input, helper class for fstcompile.cc
     43 // WARNING: Stand-alone use of this class not recommended, most code should
     44 // read/write using the binary format which is much more efficient.
     45 template <class A> class FstCompiler {
     46  public:
     47   typedef A Arc;
     48   typedef typename A::StateId StateId;
     49   typedef typename A::Label Label;
     50   typedef typename A::Weight Weight;
     51 
     52   // WARNING: use of 'allow_negative_labels = true' not recommended; may
     53   // cause conflicts
     54   FstCompiler(istream &istrm, const string &source,
     55             const SymbolTable *isyms, const SymbolTable *osyms,
     56             const SymbolTable *ssyms, bool accep, bool ikeep,
     57               bool okeep, bool nkeep, bool allow_negative_labels = false)
     58       : nline_(0), source_(source),
     59         isyms_(isyms), osyms_(osyms), ssyms_(ssyms),
     60         nstates_(0), keep_state_numbering_(nkeep),
     61         allow_negative_labels_(allow_negative_labels) {
     62     char line[kLineLen];
     63     while (istrm.getline(line, kLineLen)) {
     64       ++nline_;
     65       vector<char *> col;
     66       string separator = FLAGS_fst_field_separator + "\n";
     67       SplitToVector(line, separator.c_str(), &col, true);
     68       if (col.size() == 0 || col[0][0] == '\0')  // empty line
     69         continue;
     70       if (col.size() > 5 ||
     71           (col.size() > 4 && accep) ||
     72           (col.size() == 3 && !accep)) {
     73         FSTERROR() << "FstCompiler: Bad number of columns, source = "
     74                    << source_
     75                    << ", line = " << nline_;
     76         fst_.SetProperties(kError, kError);
     77         return;
     78       }
     79       StateId s = StrToStateId(col[0]);
     80       while (s >= fst_.NumStates())
     81         fst_.AddState();
     82       if (nline_ == 1)
     83         fst_.SetStart(s);
     84 
     85       Arc arc;
     86       StateId d = s;
     87       switch (col.size()) {
     88       case 1:
     89         fst_.SetFinal(s, Weight::One());
     90         break;
     91       case 2:
     92         fst_.SetFinal(s, StrToWeight(col[1], true));
     93         break;
     94       case 3:
     95         arc.nextstate = d = StrToStateId(col[1]);
     96         arc.ilabel = StrToILabel(col[2]);
     97         arc.olabel = arc.ilabel;
     98         arc.weight = Weight::One();
     99         fst_.AddArc(s, arc);
    100         break;
    101       case 4:
    102         arc.nextstate = d = StrToStateId(col[1]);
    103         arc.ilabel = StrToILabel(col[2]);
    104         if (accep) {
    105           arc.olabel = arc.ilabel;
    106           arc.weight = StrToWeight(col[3], false);
    107         } else {
    108           arc.olabel = StrToOLabel(col[3]);
    109           arc.weight = Weight::One();
    110         }
    111         fst_.AddArc(s, arc);
    112         break;
    113       case 5:
    114         arc.nextstate = d = StrToStateId(col[1]);
    115         arc.ilabel = StrToILabel(col[2]);
    116         arc.olabel = StrToOLabel(col[3]);
    117         arc.weight = StrToWeight(col[4], false);
    118         fst_.AddArc(s, arc);
    119       }
    120       while (d >= fst_.NumStates())
    121         fst_.AddState();
    122     }
    123     if (ikeep)
    124       fst_.SetInputSymbols(isyms);
    125     if (okeep)
    126       fst_.SetOutputSymbols(osyms);
    127   }
    128 
    129   const VectorFst<A> &Fst() const {
    130     return fst_;
    131   }
    132 
    133  private:
    134   // Maximum line length in text file.
    135   static const int kLineLen = 8096;
    136 
    137   int64 StrToId(const char *s, const SymbolTable *syms,
    138                 const char *name, bool allow_negative = false) const {
    139     int64 n = 0;
    140 
    141     if (syms) {
    142       n = syms->Find(s);
    143       if (n == -1 || (!allow_negative && n < 0)) {
    144         FSTERROR() << "FstCompiler: Symbol \"" << s
    145                    << "\" is not mapped to any integer " << name
    146                    << ", symbol table = " << syms->Name()
    147                    << ", source = " << source_ << ", line = " << nline_;
    148         fst_.SetProperties(kError, kError);
    149       }
    150     } else {
    151       char *p;
    152       n = strtoll(s, &p, 10);
    153       if (p < s + strlen(s) || (!allow_negative && n < 0)) {
    154         FSTERROR() << "FstCompiler: Bad " << name << " integer = \"" << s
    155                    << "\", source = " << source_ << ", line = " << nline_;
    156         fst_.SetProperties(kError, kError);
    157       }
    158     }
    159     return n;
    160   }
    161 
    162   StateId StrToStateId(const char *s) {
    163     StateId n = StrToId(s, ssyms_, "state ID");
    164 
    165     if (keep_state_numbering_)
    166       return n;
    167 
    168     // remap state IDs to make dense set
    169     typename unordered_map<StateId, StateId>::const_iterator it = states_.find(n);
    170     if (it == states_.end()) {
    171       states_[n] = nstates_;
    172       return nstates_++;
    173     } else {
    174       return it->second;
    175     }
    176   }
    177 
    178   StateId StrToILabel(const char *s) const {
    179     return StrToId(s, isyms_, "arc ilabel", allow_negative_labels_);
    180   }
    181 
    182   StateId StrToOLabel(const char *s) const {
    183     return StrToId(s, osyms_, "arc olabel", allow_negative_labels_);
    184   }
    185 
    186   Weight StrToWeight(const char *s, bool allow_zero) const {
    187     Weight w;
    188     istringstream strm(s);
    189     strm >> w;
    190     if (!strm || (!allow_zero && w == Weight::Zero())) {
    191       FSTERROR() << "FstCompiler: Bad weight = \"" << s
    192                  << "\", source = " << source_ << ", line = " << nline_;
    193       fst_.SetProperties(kError, kError);
    194       w = Weight::NoWeight();
    195     }
    196     return w;
    197   }
    198 
    199   mutable VectorFst<A> fst_;
    200   size_t nline_;
    201   string source_;                      // text FST source name
    202   const SymbolTable *isyms_;           // ilabel symbol table
    203   const SymbolTable *osyms_;           // olabel symbol table
    204   const SymbolTable *ssyms_;           // slabel symbol table
    205   unordered_map<StateId, StateId> states_;  // state ID map
    206   StateId nstates_;                    // number of seen states
    207   bool keep_state_numbering_;
    208   bool allow_negative_labels_;         // not recommended; may cause conflicts
    209 
    210   DISALLOW_COPY_AND_ASSIGN(FstCompiler);
    211 };
    212 
    213 }  // namespace fst
    214 
    215 #endif  // FST_SCRIPT_COMPILE_IMPL_H_
    216