Home | History | Annotate | Download | only in far
      1 
      2 // Licensed under the Apache License, Version 2.0 (the "License");
      3 // you may not use this file except in compliance with the License.
      4 // You may obtain a copy of the License at
      5 //
      6 //     http://www.apache.org/licenses/LICENSE-2.0
      7 //
      8 // Unless required by applicable law or agreed to in writing, software
      9 // distributed under the License is distributed on an "AS IS" BASIS,
     10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     11 // See the License for the specific language governing permissions and
     12 // limitations under the License.
     13 //
     14 // Copyright 2005-2010 Google, Inc.
     15 // Authors: allauzen (at) google.com (Cyril Allauzen)
     16 //          ttai (at) google.com (Terry Tai)
     17 //          jpr (at) google.com (Jake Ratkiewicz)
     18 
     19 
     20 #ifndef FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_
     21 #define FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_
     22 
     23 #include <libgen.h>
     24 #include <string>
     25 #include <vector>
     26 using std::vector;
     27 
     28 #include <fst/extensions/far/far.h>
     29 #include <fst/string.h>
     30 
     31 namespace fst {
     32 
     33 // Construct a reader that provides FSTs from a file (stream) either on a
     34 // line-by-line basis or on a per-stream basis.  Note that the freshly
     35 // constructed reader is already set to the first input.
     36 //
     37 // Sample Usage:
     38 //   for (StringReader<Arc> reader(...); !reader.Done(); reader.Next()) {
     39 //     Fst *fst = reader.GetVectorFst();
     40 //   }
     41 template <class A>
     42 class StringReader {
     43  public:
     44   typedef A Arc;
     45   typedef typename A::Label Label;
     46   typedef typename A::Weight Weight;
     47   typedef typename StringCompiler<A>::TokenType TokenType;
     48 
     49   enum EntryType { LINE = 1, FILE = 2 };
     50 
     51   StringReader(istream &istrm,
     52                const string &source,
     53                EntryType entry_type,
     54                TokenType token_type,
     55                bool allow_negative_labels,
     56                const SymbolTable *syms = 0,
     57                Label unknown_label = kNoStateId)
     58       : nline_(0), strm_(istrm), source_(source), entry_type_(entry_type),
     59         token_type_(token_type), symbols_(syms), done_(false),
     60         compiler_(token_type, syms, unknown_label, allow_negative_labels) {
     61     Next();  // Initialize the reader to the first input.
     62   }
     63 
     64   bool Done() {
     65     return done_;
     66   }
     67 
     68   void Next() {
     69     VLOG(1) << "Processing source " << source_ << " at line " << nline_;
     70     if (!strm_) {                    // We're done if we have no more input.
     71       done_ = true;
     72       return;
     73     }
     74     if (entry_type_ == LINE) {
     75       getline(strm_, content_);
     76       ++nline_;
     77     } else {
     78       content_.clear();
     79       string line;
     80       while (getline(strm_, line)) {
     81         ++nline_;
     82         content_.append(line);
     83         content_.append("\n");
     84       }
     85     }
     86     if (!strm_ && content_.empty())  // We're also done if we read off all the
     87       done_ = true;                  // whitespace at the end of a file.
     88   }
     89 
     90   VectorFst<A> *GetVectorFst(bool keep_symbols = false) {
     91     VectorFst<A> *fst = new VectorFst<A>;
     92     if (keep_symbols) {
     93       fst->SetInputSymbols(symbols_);
     94       fst->SetOutputSymbols(symbols_);
     95     }
     96     if (compiler_(content_, fst)) {
     97       return fst;
     98     } else {
     99       delete fst;
    100       return NULL;
    101     }
    102   }
    103 
    104   CompactFst<A, StringCompactor<A> > *GetCompactFst(bool keep_symbols = false) {
    105     CompactFst<A, StringCompactor<A> > *fst;
    106     if (keep_symbols) {
    107       VectorFst<A> tmp;
    108       tmp.SetInputSymbols(symbols_);
    109       tmp.SetOutputSymbols(symbols_);
    110       fst = new CompactFst<A, StringCompactor<A> >(tmp);
    111     } else {
    112       fst = new CompactFst<A, StringCompactor<A> >;
    113     }
    114     if (compiler_(content_, fst)) {
    115       return fst;
    116     } else {
    117       delete fst;
    118       return NULL;
    119     }
    120   }
    121 
    122  private:
    123   size_t nline_;
    124   istream &strm_;
    125   string source_;
    126   EntryType entry_type_;
    127   TokenType token_type_;
    128   const SymbolTable *symbols_;
    129   bool done_;
    130   StringCompiler<A> compiler_;
    131   string content_;  // The actual content of the input stream's next FST.
    132 
    133   DISALLOW_COPY_AND_ASSIGN(StringReader);
    134 };
    135 
    136 // Compute the minimal length required to encode each line number as a decimal
    137 // number.
    138 int KeySize(const char *filename);
    139 
    140 template <class Arc>
    141 void FarCompileStrings(const vector<string> &in_fnames,
    142                        const string &out_fname,
    143                        const string &fst_type,
    144                        const FarType &far_type,
    145                        int32 generate_keys,
    146                        FarEntryType fet,
    147                        FarTokenType tt,
    148                        const string &symbols_fname,
    149                        const string &unknown_symbol,
    150                        bool keep_symbols,
    151                        bool initial_symbols,
    152                        bool allow_negative_labels,
    153                        bool file_list_input,
    154                        const string &key_prefix,
    155                        const string &key_suffix) {
    156   typename StringReader<Arc>::EntryType entry_type;
    157   if (fet == FET_LINE) {
    158     entry_type = StringReader<Arc>::LINE;
    159   } else if (fet == FET_FILE) {
    160     entry_type = StringReader<Arc>::FILE;
    161   } else {
    162     FSTERROR() << "FarCompileStrings: unknown entry type";
    163     return;
    164   }
    165 
    166   typename StringCompiler<Arc>::TokenType token_type;
    167   if (tt == FTT_SYMBOL) {
    168     token_type = StringCompiler<Arc>::SYMBOL;
    169   } else if (tt == FTT_BYTE) {
    170     token_type = StringCompiler<Arc>::BYTE;
    171   } else if (tt == FTT_UTF8) {
    172     token_type = StringCompiler<Arc>::UTF8;
    173   } else {
    174     FSTERROR() << "FarCompileStrings: unknown token type";
    175     return;
    176   }
    177 
    178   bool compact;
    179   if (fst_type.empty() || (fst_type == "vector")) {
    180     compact = false;
    181   } else if (fst_type == "compact") {
    182     compact = true;
    183   } else {
    184     FSTERROR() << "FarCompileStrings: unknown fst type: "
    185                << fst_type;
    186     return;
    187   }
    188 
    189   const SymbolTable *syms = 0;
    190   typename Arc::Label unknown_label = kNoLabel;
    191   if (!symbols_fname.empty()) {
    192     SymbolTableTextOptions opts;
    193     opts.allow_negative = allow_negative_labels;
    194     syms = SymbolTable::ReadText(symbols_fname, opts);
    195     if (!syms) {
    196       FSTERROR() << "FarCompileStrings: error reading symbol table: "
    197                  << symbols_fname;
    198       return;
    199     }
    200     if (!unknown_symbol.empty()) {
    201       unknown_label = syms->Find(unknown_symbol);
    202       if (unknown_label == kNoLabel) {
    203         FSTERROR() << "FarCompileStrings: unknown label \"" << unknown_label
    204                    << "\" missing from symbol table: " << symbols_fname;
    205         return;
    206       }
    207     }
    208   }
    209 
    210   FarWriter<Arc> *far_writer =
    211       FarWriter<Arc>::Create(out_fname, far_type);
    212   if (!far_writer) return;
    213 
    214   vector<string> inputs;
    215   if (file_list_input) {
    216     for (int i = 1; i < in_fnames.size(); ++i) {
    217       istream *istrm = in_fnames.empty() ? &cin :
    218           new ifstream(in_fnames[i].c_str());
    219       string str;
    220       while (getline(*istrm, str))
    221         inputs.push_back(str);
    222       if (!in_fnames.empty())
    223         delete istrm;
    224     }
    225   } else {
    226     inputs = in_fnames;
    227   }
    228 
    229   for (int i = 0, n = 0; i < inputs.size(); ++i) {
    230     if (generate_keys == 0 && inputs[i].empty()) {
    231       FSTERROR() << "FarCompileStrings: read from a file instead of stdin or"
    232                  << " set the --generate_keys flags.";
    233       delete far_writer;
    234       delete syms;
    235       return;
    236     }
    237     int key_size = generate_keys ? generate_keys :
    238         (entry_type == StringReader<Arc>::FILE ? 1 :
    239          KeySize(inputs[i].c_str()));
    240     istream *istrm = inputs[i].empty() ? &cin :
    241         new ifstream(inputs[i].c_str());
    242 
    243     bool keep_syms = keep_symbols;
    244     for (StringReader<Arc> reader(
    245              *istrm, inputs[i].empty() ? "stdin" : inputs[i],
    246              entry_type, token_type, allow_negative_labels,
    247              syms, unknown_label);
    248          !reader.Done();
    249          reader.Next()) {
    250       ++n;
    251       const Fst<Arc> *fst;
    252       if (compact)
    253         fst = reader.GetCompactFst(keep_syms);
    254       else
    255         fst = reader.GetVectorFst(keep_syms);
    256       if (initial_symbols)
    257         keep_syms = false;
    258       if (!fst) {
    259         FSTERROR() << "FarCompileStrings: compiling string number " << n
    260                    << " in file " << inputs[i] << " failed with token_type = "
    261                    << (tt == FTT_BYTE ? "byte" :
    262                        (tt == FTT_UTF8 ? "utf8" :
    263                         (tt == FTT_SYMBOL ? "symbol" : "unknown")))
    264                    << " and entry_type = "
    265                    << (fet == FET_LINE ? "line" :
    266                        (fet == FET_FILE ? "file" : "unknown"));
    267         delete far_writer;
    268         delete syms;
    269         if (!inputs[i].empty()) delete istrm;
    270         return;
    271       }
    272       ostringstream keybuf;
    273       keybuf.width(key_size);
    274       keybuf.fill('0');
    275       keybuf << n;
    276       string key;
    277       if (generate_keys > 0) {
    278         key = keybuf.str();
    279       } else {
    280         char* filename = new char[inputs[i].size() + 1];
    281         strcpy(filename, inputs[i].c_str());
    282         key = basename(filename);
    283         if (entry_type != StringReader<Arc>::FILE) {
    284           key += "-";
    285           key += keybuf.str();
    286         }
    287         delete[] filename;
    288       }
    289       far_writer->Add(key_prefix + key + key_suffix, *fst);
    290       delete fst;
    291     }
    292     if (generate_keys == 0)
    293       n = 0;
    294     if (!inputs[i].empty())
    295       delete istrm;
    296   }
    297 
    298   delete far_writer;
    299 }
    300 
    301 }  // namespace fst
    302 
    303 
    304 #endif  // FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_
    305