1 2 // string.h 3 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 // 16 // Copyright 2005-2010 Google, Inc. 17 // Author: allauzen (at) google.com (Cyril Allauzen) 18 // 19 // \file 20 // Utilities to convert strings into FSTs. 21 // 22 23 #ifndef FST_LIB_STRING_H_ 24 #define FST_LIB_STRING_H_ 25 26 #include <fst/compact-fst.h> 27 #include <fst/mutable-fst.h> 28 29 DECLARE_string(fst_field_separator); 30 31 namespace fst { 32 33 // Functor compiling a string in an FST 34 template <class A> 35 class StringCompiler { 36 public: 37 typedef A Arc; 38 typedef typename A::Label Label; 39 typedef typename A::Weight Weight; 40 41 enum TokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 }; 42 43 StringCompiler(TokenType type, const SymbolTable *syms = 0, 44 Label unknown_label = kNoLabel, 45 bool allow_negative = false) 46 : token_type_(type), syms_(syms), unknown_label_(unknown_label), 47 allow_negative_(allow_negative) {} 48 49 // Compile string 's' into FST 'fst'. 50 template <class F> 51 bool operator()(const string &s, F *fst) { 52 vector<Label> labels; 53 if (!ConvertStringToLabels(s, &labels)) 54 return false; 55 Compile(labels, fst); 56 return true; 57 } 58 59 private: 60 bool ConvertStringToLabels(const string &str, vector<Label> *labels) const { 61 labels->clear(); 62 if (token_type_ == BYTE) { 63 for (size_t i = 0; i < str.size(); ++i) 64 labels->push_back(static_cast<unsigned char>(str[i])); 65 } else if (token_type_ == UTF8) { 66 return UTF8StringToLabels(str, labels); 67 } else { 68 char *c_str = new char[str.size() + 1]; 69 str.copy(c_str, str.size()); 70 c_str[str.size()] = 0; 71 vector<char *> vec; 72 string separator = "\n" + FLAGS_fst_field_separator; 73 SplitToVector(c_str, separator.c_str(), &vec, true); 74 for (size_t i = 0; i < vec.size(); ++i) { 75 Label label; 76 if (!ConvertSymbolToLabel(vec[i], &label)) 77 return false; 78 labels->push_back(label); 79 } 80 delete[] c_str; 81 } 82 return true; 83 } 84 85 void Compile(const vector<Label> &labels, MutableFst<A> *fst) const { 86 fst->DeleteStates(); 87 while (fst->NumStates() <= labels.size()) 88 fst->AddState(); 89 for (size_t i = 0; i < labels.size(); ++i) 90 fst->AddArc(i, Arc(labels[i], labels[i], Weight::One(), i + 1)); 91 fst->SetStart(0); 92 fst->SetFinal(labels.size(), Weight::One()); 93 } 94 95 template <class Unsigned> 96 void Compile(const vector<Label> &labels, CompactFst<A, StringCompactor<A>, 97 Unsigned> *fst) const { 98 fst->SetCompactElements(labels.begin(), labels.end()); 99 } 100 101 bool ConvertSymbolToLabel(const char *s, Label* output) const { 102 int64 n; 103 if (syms_) { 104 n = syms_->Find(s); 105 if ((n == -1) && (unknown_label_ != kNoLabel)) 106 n = unknown_label_; 107 if (n == -1 || (!allow_negative_ && n < 0)) { 108 VLOG(1) << "StringCompiler::ConvertSymbolToLabel: Symbol \"" << s 109 << "\" is not mapped to any integer label, symbol table = " 110 << syms_->Name(); 111 return false; 112 } 113 } else { 114 char *p; 115 n = strtoll(s, &p, 10); 116 if (p < s + strlen(s) || (!allow_negative_ && n < 0)) { 117 VLOG(1) << "StringCompiler::ConvertSymbolToLabel: Bad label integer " 118 << "= \"" << s << "\""; 119 return false; 120 } 121 } 122 *output = n; 123 return true; 124 } 125 126 TokenType token_type_; // Token type: symbol, byte or utf8 encoded 127 const SymbolTable *syms_; // Symbol table used when token type is symbol 128 Label unknown_label_; // Label for token missing from symbol table 129 bool allow_negative_; // Negative labels allowed? 130 131 DISALLOW_COPY_AND_ASSIGN(StringCompiler); 132 }; 133 134 // Functor to print a string FST as a string. 135 template <class A> 136 class StringPrinter { 137 public: 138 typedef A Arc; 139 typedef typename A::Label Label; 140 typedef typename A::StateId StateId; 141 typedef typename A::Weight Weight; 142 143 enum TokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 }; 144 145 StringPrinter(TokenType token_type, 146 const SymbolTable *syms = 0) 147 : token_type_(token_type), syms_(syms) {} 148 149 // Convert the FST 'fst' into the string 'output' 150 bool operator()(const Fst<A> &fst, string *output) { 151 bool is_a_string = FstToLabels(fst); 152 if (!is_a_string) { 153 VLOG(1) << "StringPrinter::operator(): Fst is not a string."; 154 return false; 155 } 156 157 output->clear(); 158 159 if (token_type_ == SYMBOL) { 160 stringstream sstrm; 161 for (size_t i = 0; i < labels_.size(); ++i) { 162 if (i) 163 sstrm << *(FLAGS_fst_field_separator.rbegin()); 164 if (!PrintLabel(labels_[i], sstrm)) 165 return false; 166 } 167 *output = sstrm.str(); 168 } else if (token_type_ == BYTE) { 169 for (size_t i = 0; i < labels_.size(); ++i) { 170 output->push_back(labels_[i]); 171 } 172 } else if (token_type_ == UTF8) { 173 return LabelsToUTF8String(labels_, output); 174 } else { 175 VLOG(1) << "StringPrinter::operator(): Unknown token type: " 176 << token_type_; 177 return false; 178 } 179 return true; 180 } 181 182 private: 183 bool FstToLabels(const Fst<A> &fst) { 184 labels_.clear(); 185 186 StateId s = fst.Start(); 187 if (s == kNoStateId) { 188 VLOG(2) << "StringPrinter::FstToLabels: Invalid starting state for " 189 << "string fst."; 190 return false; 191 } 192 193 while (fst.Final(s) == Weight::Zero()) { 194 ArcIterator<Fst<A> > aiter(fst, s); 195 if (aiter.Done()) { 196 VLOG(2) << "StringPrinter::FstToLabels: String fst traversal does " 197 << "not reach final state."; 198 return false; 199 } 200 201 const A& arc = aiter.Value(); 202 labels_.push_back(arc.olabel); 203 204 s = arc.nextstate; 205 if (s == kNoStateId) { 206 VLOG(2) << "StringPrinter::FstToLabels: Transition to invalid " 207 << "state."; 208 return false; 209 } 210 211 aiter.Next(); 212 if (!aiter.Done()) { 213 VLOG(2) << "StringPrinter::FstToLabels: State with multiple " 214 << "outgoing arcs found."; 215 return false; 216 } 217 } 218 219 return true; 220 } 221 222 bool PrintLabel(Label lab, ostream& ostrm) { 223 if (syms_) { 224 string symbol = syms_->Find(lab); 225 if (symbol == "") { 226 VLOG(2) << "StringPrinter::PrintLabel: Integer " << lab << " is not " 227 << "mapped to any textual symbol, symbol table = " 228 << syms_->Name(); 229 return false; 230 } 231 ostrm << symbol; 232 } else { 233 ostrm << lab; 234 } 235 return true; 236 } 237 238 TokenType token_type_; // Token type: symbol, byte or utf8 encoded 239 const SymbolTable *syms_; // Symbol table used when token type is symbol 240 vector<Label> labels_; // Input FST labels. 241 242 DISALLOW_COPY_AND_ASSIGN(StringPrinter); 243 }; 244 245 } // namespace fst 246 247 #endif // FST_LIB_STRING_H_ 248