Home | History | Annotate | Download | only in fst
      1 
      2 // string.h
      3 
      4 // Licensed under the Apache License, Version 2.0 (the "License");
      5 // you may not use this file except in compliance with the License.
      6 // You may obtain a copy of the License at
      7 //
      8 //     http://www.apache.org/licenses/LICENSE-2.0
      9 //
     10 // Unless required by applicable law or agreed to in writing, software
     11 // distributed under the License is distributed on an "AS IS" BASIS,
     12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 // See the License for the specific language governing permissions and
     14 // limitations under the License.
     15 //
     16 // Copyright 2005-2010 Google, Inc.
     17 // Author: allauzen (at) google.com (Cyril Allauzen)
     18 //
     19 // \file
     20 // Utilities to convert strings into FSTs.
     21 //
     22 
     23 #ifndef FST_LIB_STRING_H_
     24 #define FST_LIB_STRING_H_
     25 
     26 #include <fst/compact-fst.h>
     27 #include <fst/icu.h>
     28 #include <fst/mutable-fst.h>
     29 
     30 DECLARE_string(fst_field_separator);
     31 
     32 namespace fst {
     33 
     34 // Functor compiling a string in an FST
     35 template <class A>
     36 class StringCompiler {
     37  public:
     38   typedef A Arc;
     39   typedef typename A::Label Label;
     40   typedef typename A::Weight Weight;
     41 
     42   enum TokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 };
     43 
     44   StringCompiler(TokenType type, const SymbolTable *syms = 0,
     45                  Label unknown_label = kNoLabel,
     46                  bool allow_negative = false)
     47       : token_type_(type), syms_(syms), unknown_label_(unknown_label),
     48         allow_negative_(allow_negative) {}
     49 
     50   // Compile string 's' into FST 'fst'.
     51   template <class F>
     52   bool operator()(const string &s, F *fst) const {
     53     vector<Label> labels;
     54     if (!ConvertStringToLabels(s, &labels))
     55       return false;
     56     Compile(labels, fst);
     57     return true;
     58   }
     59 
     60   template <class F>
     61   bool operator()(const string &s, F *fst, Weight w) const {
     62     vector<Label> labels;
     63     if (!ConvertStringToLabels(s, &labels))
     64       return false;
     65     Compile(labels, fst, w);
     66     return true;
     67   }
     68 
     69  private:
     70   bool ConvertStringToLabels(const string &str, vector<Label> *labels) const {
     71     labels->clear();
     72     if (token_type_ == BYTE) {
     73       for (size_t i = 0; i < str.size(); ++i)
     74         labels->push_back(static_cast<unsigned char>(str[i]));
     75     } else if (token_type_ == UTF8) {
     76       return UTF8StringToLabels(str, labels);
     77     } else {
     78       char *c_str = new char[str.size() + 1];
     79       str.copy(c_str, str.size());
     80       c_str[str.size()] = 0;
     81       vector<char *> vec;
     82       string separator = "\n" + FLAGS_fst_field_separator;
     83       SplitToVector(c_str, separator.c_str(), &vec, true);
     84       for (size_t i = 0; i < vec.size(); ++i) {
     85         Label label;
     86         if (!ConvertSymbolToLabel(vec[i], &label))
     87           return false;
     88         labels->push_back(label);
     89       }
     90       delete[] c_str;
     91     }
     92     return true;
     93   }
     94 
     95   void Compile(const vector<Label> &labels, MutableFst<A> *fst,
     96                const Weight &weight = Weight::One()) const {
     97     fst->DeleteStates();
     98     while (fst->NumStates() <= labels.size())
     99       fst->AddState();
    100     for (size_t i = 0; i < labels.size(); ++i)
    101       fst->AddArc(i, Arc(labels[i], labels[i], Weight::One(), i + 1));
    102     fst->SetStart(0);
    103     fst->SetFinal(labels.size(), weight);
    104   }
    105 
    106   template <class Unsigned>
    107   void Compile(const vector<Label> &labels,
    108                CompactFst<A, StringCompactor<A>, Unsigned> *fst) const {
    109     fst->SetCompactElements(labels.begin(), labels.end());
    110   }
    111 
    112   template <class Unsigned>
    113   void Compile(const vector<Label> &labels,
    114                CompactFst<A, WeightedStringCompactor<A>, Unsigned> *fst,
    115                const Weight &weight = Weight::One()) const {
    116     vector<pair<Label, Weight> > compacts;
    117     compacts.reserve(labels.size());
    118     for (size_t i = 0; i < labels.size(); ++i)
    119       compacts.push_back(make_pair(labels[i], Weight::One()));
    120     compacts.back().second = weight;
    121     fst->SetCompactElements(compacts.begin(), compacts.end());
    122   }
    123 
    124   bool ConvertSymbolToLabel(const char *s, Label* output) const {
    125     int64 n;
    126     if (syms_) {
    127       n = syms_->Find(s);
    128       if ((n == -1) && (unknown_label_ != kNoLabel))
    129         n = unknown_label_;
    130       if (n == -1 || (!allow_negative_ && n < 0)) {
    131         VLOG(1) << "StringCompiler::ConvertSymbolToLabel: Symbol \"" << s
    132                 << "\" is not mapped to any integer label, symbol table = "
    133                  << syms_->Name();
    134         return false;
    135       }
    136     } else {
    137       char *p;
    138       n = strtoll(s, &p, 10);
    139       if (p < s + strlen(s) || (!allow_negative_ && n < 0)) {
    140         VLOG(1) << "StringCompiler::ConvertSymbolToLabel: Bad label integer "
    141                 << "= \"" << s << "\"";
    142         return false;
    143       }
    144     }
    145     *output = n;
    146     return true;
    147   }
    148 
    149   TokenType token_type_;     // Token type: symbol, byte or utf8 encoded
    150   const SymbolTable *syms_;  // Symbol table used when token type is symbol
    151   Label unknown_label_;      // Label for token missing from symbol table
    152   bool allow_negative_;      // Negative labels allowed?
    153 
    154   DISALLOW_COPY_AND_ASSIGN(StringCompiler);
    155 };
    156 
    157 // Functor to print a string FST as a string.
    158 template <class A>
    159 class StringPrinter {
    160  public:
    161   typedef A Arc;
    162   typedef typename A::Label Label;
    163   typedef typename A::StateId StateId;
    164   typedef typename A::Weight Weight;
    165 
    166   enum TokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 };
    167 
    168   StringPrinter(TokenType token_type,
    169                 const SymbolTable *syms = 0)
    170       : token_type_(token_type), syms_(syms) {}
    171 
    172   // Convert the FST 'fst' into the string 'output'
    173   bool operator()(const Fst<A> &fst, string *output) {
    174     bool is_a_string = FstToLabels(fst);
    175     if (!is_a_string) {
    176       VLOG(1) << "StringPrinter::operator(): Fst is not a string.";
    177       return false;
    178     }
    179 
    180     output->clear();
    181 
    182     if (token_type_ == SYMBOL) {
    183       stringstream sstrm;
    184       for (size_t i = 0; i < labels_.size(); ++i) {
    185         if (i)
    186           sstrm << *(FLAGS_fst_field_separator.rbegin());
    187         if (!PrintLabel(labels_[i], sstrm))
    188           return false;
    189       }
    190       *output = sstrm.str();
    191     } else if (token_type_ == BYTE) {
    192       output->reserve(labels_.size());
    193       for (size_t i = 0; i < labels_.size(); ++i) {
    194         output->push_back(labels_[i]);
    195       }
    196     } else if (token_type_ == UTF8) {
    197       return LabelsToUTF8String(labels_, output);
    198     } else {
    199       VLOG(1) << "StringPrinter::operator(): Unknown token type: "
    200               << token_type_;
    201       return false;
    202     }
    203     return true;
    204   }
    205 
    206  private:
    207   bool FstToLabels(const Fst<A> &fst) {
    208     labels_.clear();
    209 
    210     StateId s = fst.Start();
    211     if (s == kNoStateId) {
    212       VLOG(2) << "StringPrinter::FstToLabels: Invalid starting state for "
    213               << "string fst.";
    214       return false;
    215     }
    216 
    217     while (fst.Final(s) == Weight::Zero()) {
    218       ArcIterator<Fst<A> > aiter(fst, s);
    219       if (aiter.Done()) {
    220         VLOG(2) << "StringPrinter::FstToLabels: String fst traversal does "
    221                 << "not reach final state.";
    222         return false;
    223       }
    224 
    225       const A& arc = aiter.Value();
    226       labels_.push_back(arc.olabel);
    227 
    228       s = arc.nextstate;
    229       if (s == kNoStateId) {
    230         VLOG(2) << "StringPrinter::FstToLabels: Transition to invalid "
    231                 << "state.";
    232         return false;
    233       }
    234 
    235       aiter.Next();
    236       if (!aiter.Done()) {
    237         VLOG(2) << "StringPrinter::FstToLabels: State with multiple "
    238                 << "outgoing arcs found.";
    239         return false;
    240       }
    241     }
    242 
    243     return true;
    244   }
    245 
    246   bool PrintLabel(Label lab, ostream& ostrm) {
    247     if (syms_) {
    248       string symbol = syms_->Find(lab);
    249       if (symbol == "") {
    250         VLOG(2) << "StringPrinter::PrintLabel: Integer " << lab << " is not "
    251                 << "mapped to any textual symbol, symbol table = "
    252                  << syms_->Name();
    253         return false;
    254       }
    255       ostrm << symbol;
    256     } else {
    257       ostrm << lab;
    258     }
    259     return true;
    260   }
    261 
    262   TokenType token_type_;     // Token type: symbol, byte or utf8 encoded
    263   const SymbolTable *syms_;  // Symbol table used when token type is symbol
    264   vector<Label> labels_;     // Input FST labels.
    265 
    266   DISALLOW_COPY_AND_ASSIGN(StringPrinter);
    267 };
    268 
    269 }  // namespace fst
    270 
    271 #endif // FST_LIB_STRING_H_
    272