Home | History | Annotate | Download | only in fst
      1 
      2 // string.h
      3 
      4 // Licensed under the Apache License, Version 2.0 (the "License");
      5 // you may not use this file except in compliance with the License.
      6 // You may obtain a copy of the License at
      7 //
      8 //     http://www.apache.org/licenses/LICENSE-2.0
      9 //
     10 // Unless required by applicable law or agreed to in writing, software
     11 // distributed under the License is distributed on an "AS IS" BASIS,
     12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 // See the License for the specific language governing permissions and
     14 // limitations under the License.
     15 //
     16 // Copyright 2005-2010 Google, Inc.
     17 // Author: allauzen (at) google.com (Cyril Allauzen)
     18 //
     19 // \file
     20 // Utilities to convert strings into FSTs.
     21 //
     22 
     23 #ifndef FST_LIB_STRING_H_
     24 #define FST_LIB_STRING_H_
     25 
     26 #include <fst/compact-fst.h>
     27 #include <fst/mutable-fst.h>
     28 
     29 DECLARE_string(fst_field_separator);
     30 
     31 namespace fst {
     32 
     33 // Functor compiling a string in an FST
     34 template <class A>
     35 class StringCompiler {
     36  public:
     37   typedef A Arc;
     38   typedef typename A::Label Label;
     39   typedef typename A::Weight Weight;
     40 
     41   enum TokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 };
     42 
     43   StringCompiler(TokenType type, const SymbolTable *syms = 0,
     44                  Label unknown_label = kNoLabel,
     45                  bool allow_negative = false)
     46       : token_type_(type), syms_(syms), unknown_label_(unknown_label),
     47         allow_negative_(allow_negative) {}
     48 
     49   // Compile string 's' into FST 'fst'.
     50   template <class F>
     51   bool operator()(const string &s, F *fst) {
     52     vector<Label> labels;
     53     if (!ConvertStringToLabels(s, &labels))
     54       return false;
     55     Compile(labels, fst);
     56     return true;
     57   }
     58 
     59  private:
     60   bool ConvertStringToLabels(const string &str, vector<Label> *labels) const {
     61     labels->clear();
     62     if (token_type_ == BYTE) {
     63       for (size_t i = 0; i < str.size(); ++i)
     64         labels->push_back(static_cast<unsigned char>(str[i]));
     65     } else if (token_type_ == UTF8) {
     66       return UTF8StringToLabels(str, labels);
     67     } else {
     68       char *c_str = new char[str.size() + 1];
     69       str.copy(c_str, str.size());
     70       c_str[str.size()] = 0;
     71       vector<char *> vec;
     72       string separator = "\n" + FLAGS_fst_field_separator;
     73       SplitToVector(c_str, separator.c_str(), &vec, true);
     74       for (size_t i = 0; i < vec.size(); ++i) {
     75         Label label;
     76         if (!ConvertSymbolToLabel(vec[i], &label))
     77           return false;
     78         labels->push_back(label);
     79       }
     80       delete[] c_str;
     81     }
     82     return true;
     83   }
     84 
     85   void Compile(const vector<Label> &labels, MutableFst<A> *fst) const {
     86     fst->DeleteStates();
     87     while (fst->NumStates() <= labels.size())
     88       fst->AddState();
     89     for (size_t i = 0; i < labels.size(); ++i)
     90       fst->AddArc(i, Arc(labels[i], labels[i], Weight::One(), i + 1));
     91     fst->SetStart(0);
     92     fst->SetFinal(labels.size(), Weight::One());
     93   }
     94 
     95   template <class Unsigned>
     96   void Compile(const vector<Label> &labels, CompactFst<A, StringCompactor<A>,
     97                Unsigned> *fst) const {
     98     fst->SetCompactElements(labels.begin(), labels.end());
     99   }
    100 
    101   bool ConvertSymbolToLabel(const char *s, Label* output) const {
    102     int64 n;
    103     if (syms_) {
    104       n = syms_->Find(s);
    105       if ((n == -1) && (unknown_label_ != kNoLabel))
    106         n = unknown_label_;
    107       if (n == -1 || (!allow_negative_ && n < 0)) {
    108         VLOG(1) << "StringCompiler::ConvertSymbolToLabel: Symbol \"" << s
    109                 << "\" is not mapped to any integer label, symbol table = "
    110                  << syms_->Name();
    111         return false;
    112       }
    113     } else {
    114       char *p;
    115       n = strtoll(s, &p, 10);
    116       if (p < s + strlen(s) || (!allow_negative_ && n < 0)) {
    117         VLOG(1) << "StringCompiler::ConvertSymbolToLabel: Bad label integer "
    118                 << "= \"" << s << "\"";
    119         return false;
    120       }
    121     }
    122     *output = n;
    123     return true;
    124   }
    125 
    126   TokenType token_type_;     // Token type: symbol, byte or utf8 encoded
    127   const SymbolTable *syms_;  // Symbol table used when token type is symbol
    128   Label unknown_label_;      // Label for token missing from symbol table
    129   bool allow_negative_;      // Negative labels allowed?
    130 
    131   DISALLOW_COPY_AND_ASSIGN(StringCompiler);
    132 };
    133 
    134 // Functor to print a string FST as a string.
    135 template <class A>
    136 class StringPrinter {
    137  public:
    138   typedef A Arc;
    139   typedef typename A::Label Label;
    140   typedef typename A::StateId StateId;
    141   typedef typename A::Weight Weight;
    142 
    143   enum TokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 };
    144 
    145   StringPrinter(TokenType token_type,
    146                 const SymbolTable *syms = 0)
    147       : token_type_(token_type), syms_(syms) {}
    148 
    149   // Convert the FST 'fst' into the string 'output'
    150   bool operator()(const Fst<A> &fst, string *output) {
    151     bool is_a_string = FstToLabels(fst);
    152     if (!is_a_string) {
    153       VLOG(1) << "StringPrinter::operator(): Fst is not a string.";
    154       return false;
    155     }
    156 
    157     output->clear();
    158 
    159     if (token_type_ == SYMBOL) {
    160       stringstream sstrm;
    161       for (size_t i = 0; i < labels_.size(); ++i) {
    162         if (i)
    163           sstrm << *(FLAGS_fst_field_separator.rbegin());
    164         if (!PrintLabel(labels_[i], sstrm))
    165           return false;
    166       }
    167       *output = sstrm.str();
    168     } else if (token_type_ == BYTE) {
    169       for (size_t i = 0; i < labels_.size(); ++i) {
    170         output->push_back(labels_[i]);
    171       }
    172     } else if (token_type_ == UTF8) {
    173       return LabelsToUTF8String(labels_, output);
    174     } else {
    175       VLOG(1) << "StringPrinter::operator(): Unknown token type: "
    176               << token_type_;
    177       return false;
    178     }
    179     return true;
    180   }
    181 
    182  private:
    183   bool FstToLabels(const Fst<A> &fst) {
    184     labels_.clear();
    185 
    186     StateId s = fst.Start();
    187     if (s == kNoStateId) {
    188       VLOG(2) << "StringPrinter::FstToLabels: Invalid starting state for "
    189               << "string fst.";
    190       return false;
    191     }
    192 
    193     while (fst.Final(s) == Weight::Zero()) {
    194       ArcIterator<Fst<A> > aiter(fst, s);
    195       if (aiter.Done()) {
    196         VLOG(2) << "StringPrinter::FstToLabels: String fst traversal does "
    197                 << "not reach final state.";
    198         return false;
    199       }
    200 
    201       const A& arc = aiter.Value();
    202       labels_.push_back(arc.olabel);
    203 
    204       s = arc.nextstate;
    205       if (s == kNoStateId) {
    206         VLOG(2) << "StringPrinter::FstToLabels: Transition to invalid "
    207                 << "state.";
    208         return false;
    209       }
    210 
    211       aiter.Next();
    212       if (!aiter.Done()) {
    213         VLOG(2) << "StringPrinter::FstToLabels: State with multiple "
    214                 << "outgoing arcs found.";
    215         return false;
    216       }
    217     }
    218 
    219     return true;
    220   }
    221 
    222   bool PrintLabel(Label lab, ostream& ostrm) {
    223     if (syms_) {
    224       string symbol = syms_->Find(lab);
    225       if (symbol == "") {
    226         VLOG(2) << "StringPrinter::PrintLabel: Integer " << lab << " is not "
    227                 << "mapped to any textual symbol, symbol table = "
    228                  << syms_->Name();
    229         return false;
    230       }
    231       ostrm << symbol;
    232     } else {
    233       ostrm << lab;
    234     }
    235     return true;
    236   }
    237 
    238   TokenType token_type_;     // Token type: symbol, byte or utf8 encoded
    239   const SymbolTable *syms_;  // Symbol table used when token type is symbol
    240   vector<Label> labels_;     // Input FST labels.
    241 
    242   DISALLOW_COPY_AND_ASSIGN(StringPrinter);
    243 };
    244 
    245 }  // namespace fst
    246 
    247 #endif // FST_LIB_STRING_H_
    248