1 // replace-util.h 2 3 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 // 16 // Copyright 2005-2010 Google, Inc. 17 // Author: riley (at) google.com (Michael Riley) 18 // 19 20 // \file 21 // Utility classes for the recursive replacement of Fsts (RTNs). 22 23 #ifndef FST_LIB_REPLACE_UTIL_H__ 24 #define FST_LIB_REPLACE_UTIL_H__ 25 26 #include <vector> 27 using std::vector; 28 #include <tr1/unordered_map> 29 using std::tr1::unordered_map; 30 using std::tr1::unordered_multimap; 31 #include <tr1/unordered_set> 32 using std::tr1::unordered_set; 33 using std::tr1::unordered_multiset; 34 #include <map> 35 36 #include <fst/connect.h> 37 #include <fst/mutable-fst.h> 38 #include <fst/topsort.h> 39 40 41 namespace fst { 42 43 template <class Arc> 44 void Replace(const vector<pair<typename Arc::Label, const Fst<Arc>* > >&, 45 MutableFst<Arc> *, typename Arc::Label, bool); 46 47 48 // Utility class for the recursive replacement of Fsts (RTNs). The 49 // user provides a set of Label, Fst pairs at construction. These are 50 // used by methods for testing cyclic dependencies and connectedness 51 // and doing RTN connection and specific Fst replacement by label or 52 // for various optimization properties. The modified results can be 53 // obtained with the GetFstPairs() or GetMutableFstPairs() methods. 54 template <class Arc> 55 class ReplaceUtil { 56 public: 57 typedef typename Arc::Label Label; 58 typedef typename Arc::Weight Weight; 59 typedef typename Arc::StateId StateId; 60 61 typedef pair<Label, const Fst<Arc>*> FstPair; 62 typedef pair<Label, MutableFst<Arc>*> MutableFstPair; 63 typedef unordered_map<Label, Label> NonTerminalHash; 64 65 // Constructs from mutable Fsts; Fst ownership given to ReplaceUtil. 66 ReplaceUtil(const vector<MutableFstPair> &fst_pairs, 67 Label root_label, bool epsilon_on_replace = false); 68 69 // Constructs from Fsts; Fst ownership retained by caller. 70 ReplaceUtil(const vector<FstPair> &fst_pairs, 71 Label root_label, bool epsilon_on_replace = false); 72 73 // Constructs from ReplaceFst internals; ownership retained by caller. 74 ReplaceUtil(const vector<const Fst<Arc> *> &fst_array, 75 const NonTerminalHash &nonterminal_hash, Label root_fst, 76 bool epsilon_on_replace = false); 77 78 ~ReplaceUtil() { 79 for (Label i = 0; i < fst_array_.size(); ++i) 80 delete fst_array_[i]; 81 } 82 83 // True if the non-terminal dependencies are cyclic. Cyclic 84 // dependencies will result in an unexpandable replace fst. 85 bool CyclicDependencies() const { 86 GetDependencies(false); 87 return depprops_ & kCyclic; 88 } 89 90 // Returns true if no useless Fsts, states or transitions. 91 bool Connected() const { 92 GetDependencies(false); 93 uint64 props = kAccessible | kCoAccessible; 94 for (Label i = 0; i < fst_array_.size(); ++i) { 95 if (!fst_array_[i]) 96 continue; 97 if (fst_array_[i]->Properties(props, true) != props || !depaccess_[i]) 98 return false; 99 } 100 return true; 101 } 102 103 // Removes useless Fsts, states and transitions. 104 void Connect(); 105 106 // Replaces Fsts specified by labels. 107 // Does nothing if there are cyclic dependencies. 108 void ReplaceLabels(const vector<Label> &labels); 109 110 // Replaces Fsts that have at most 'nstates' states, 'narcs' arcs and 111 // 'nnonterm' non-terminals (updating in reverse dependency order). 112 // Does nothing if there are cyclic dependencies. 113 void ReplaceBySize(size_t nstates, size_t narcs, size_t nnonterms); 114 115 // Replaces singleton Fsts. 116 // Does nothing if there are cyclic dependencies. 117 void ReplaceTrivial() { ReplaceBySize(2, 1, 1); } 118 119 // Replaces non-terminals that have at most 'ninstances' instances 120 // (updating in dependency order). 121 // Does nothing if there are cyclic dependencies. 122 void ReplaceByInstances(size_t ninstances); 123 124 // Replaces non-terminals that have only one instance. 125 // Does nothing if there are cyclic dependencies. 126 void ReplaceUnique() { ReplaceByInstances(1); } 127 128 // Returns Label, Fst pairs; Fst ownership retained by ReplaceUtil. 129 void GetFstPairs(vector<FstPair> *fst_pairs); 130 131 // Returns Label, MutableFst pairs; Fst ownership given to caller. 132 void GetMutableFstPairs(vector<MutableFstPair> *mutable_fst_pairs); 133 134 private: 135 // Per Fst statistics 136 struct ReplaceStats { 137 StateId nstates; // # of states 138 StateId nfinal; // # of final states 139 size_t narcs; // # of arcs 140 Label nnonterms; // # of non-terminals in Fst 141 size_t nref; // # of non-terminal instances referring to this Fst 142 143 // # of times that ith Fst references this Fst 144 map<Label, size_t> inref; 145 // # of times that this Fst references the ith Fst 146 map<Label, size_t> outref; 147 148 ReplaceStats() 149 : nstates(0), 150 nfinal(0), 151 narcs(0), 152 nnonterms(0), 153 nref(0) {} 154 }; 155 156 // Check Mutable Fsts exist o.w. create them. 157 void CheckMutableFsts(); 158 159 // Computes the dependency graph of the replace Fsts. 160 // If 'stats' is true, dependency statistics computed as well. 161 void GetDependencies(bool stats) const; 162 163 void ClearDependencies() const { 164 depfst_.DeleteStates(); 165 stats_.clear(); 166 depprops_ = 0; 167 have_stats_ = false; 168 } 169 170 // Get topological order of dependencies. Returns false with cyclic input. 171 bool GetTopOrder(const Fst<Arc> &fst, vector<Label> *toporder) const; 172 173 // Update statistics assuming that jth Fst will be replaced. 174 void UpdateStats(Label j); 175 176 Label root_label_; // root non-terminal 177 Label root_fst_; // root Fst ID 178 bool epsilon_on_replace_; // see Replace() 179 vector<const Fst<Arc> *> fst_array_; // Fst per ID 180 vector<MutableFst<Arc> *> mutable_fst_array_; // MutableFst per ID 181 vector<Label> nonterminal_array_; // Fst ID to non-terminal 182 NonTerminalHash nonterminal_hash_; // non-terminal to Fst ID 183 mutable VectorFst<Arc> depfst_; // Fst ID dependencies 184 mutable vector<bool> depaccess_; // Fst ID accessibility 185 mutable uint64 depprops_; // dependency Fst props 186 mutable bool have_stats_; // have dependency statistics 187 mutable vector<ReplaceStats> stats_; // Per Fst statistics 188 DISALLOW_COPY_AND_ASSIGN(ReplaceUtil); 189 }; 190 191 template <class Arc> 192 ReplaceUtil<Arc>::ReplaceUtil( 193 const vector<MutableFstPair> &fst_pairs, 194 Label root_label, bool epsilon_on_replace) 195 : root_label_(root_label), 196 epsilon_on_replace_(epsilon_on_replace), 197 depprops_(0), 198 have_stats_(false) { 199 fst_array_.push_back(0); 200 mutable_fst_array_.push_back(0); 201 nonterminal_array_.push_back(kNoLabel); 202 for (Label i = 0; i < fst_pairs.size(); ++i) { 203 Label label = fst_pairs[i].first; 204 MutableFst<Arc> *fst = fst_pairs[i].second; 205 nonterminal_hash_[label] = fst_array_.size(); 206 nonterminal_array_.push_back(label); 207 fst_array_.push_back(fst); 208 mutable_fst_array_.push_back(fst); 209 } 210 root_fst_ = nonterminal_hash_[root_label_]; 211 if (!root_fst_) 212 FSTERROR() << "ReplaceUtil: no root FST for label: " << root_label_; 213 } 214 215 template <class Arc> 216 ReplaceUtil<Arc>::ReplaceUtil( 217 const vector<FstPair> &fst_pairs, 218 Label root_label, bool epsilon_on_replace) 219 : root_label_(root_label), 220 epsilon_on_replace_(epsilon_on_replace), 221 depprops_(0), 222 have_stats_(false) { 223 fst_array_.push_back(0); 224 nonterminal_array_.push_back(kNoLabel); 225 for (Label i = 0; i < fst_pairs.size(); ++i) { 226 Label label = fst_pairs[i].first; 227 const Fst<Arc> *fst = fst_pairs[i].second; 228 nonterminal_hash_[label] = fst_array_.size(); 229 nonterminal_array_.push_back(label); 230 fst_array_.push_back(fst->Copy()); 231 } 232 root_fst_ = nonterminal_hash_[root_label]; 233 if (!root_fst_) 234 FSTERROR() << "ReplaceUtil: no root FST for label: " << root_label_; 235 } 236 237 template <class Arc> 238 ReplaceUtil<Arc>::ReplaceUtil( 239 const vector<const Fst<Arc> *> &fst_array, 240 const NonTerminalHash &nonterminal_hash, Label root_fst, 241 bool epsilon_on_replace) 242 : root_fst_(root_fst), 243 epsilon_on_replace_(epsilon_on_replace), 244 nonterminal_array_(fst_array.size()), 245 nonterminal_hash_(nonterminal_hash), 246 depprops_(0), 247 have_stats_(false) { 248 fst_array_.push_back(0); 249 for (Label i = 1; i < fst_array.size(); ++i) 250 fst_array_.push_back(fst_array[i]->Copy()); 251 for (typename NonTerminalHash::const_iterator it = 252 nonterminal_hash.begin(); it != nonterminal_hash.end(); ++it) 253 nonterminal_array_[it->second] = it->first; 254 root_label_ = nonterminal_array_[root_fst_]; 255 } 256 257 template <class Arc> 258 void ReplaceUtil<Arc>::GetDependencies(bool stats) const { 259 if (depfst_.NumStates() > 0) { 260 if (stats && !have_stats_) 261 ClearDependencies(); 262 else 263 return; 264 } 265 266 have_stats_ = stats; 267 if (have_stats_) 268 stats_.reserve(fst_array_.size()); 269 270 for (Label i = 0; i < fst_array_.size(); ++i) { 271 depfst_.AddState(); 272 depfst_.SetFinal(i, Weight::One()); 273 if (have_stats_) 274 stats_.push_back(ReplaceStats()); 275 } 276 depfst_.SetStart(root_fst_); 277 278 // An arc from each state (representing the fst) to the 279 // state representing the fst being replaced 280 for (Label i = 0; i < fst_array_.size(); ++i) { 281 const Fst<Arc> *ifst = fst_array_[i]; 282 if (!ifst) 283 continue; 284 for (StateIterator<Fst<Arc> > siter(*ifst); !siter.Done(); siter.Next()) { 285 StateId s = siter.Value(); 286 if (have_stats_) { 287 ++stats_[i].nstates; 288 if (ifst->Final(s) != Weight::Zero()) 289 ++stats_[i].nfinal; 290 } 291 for (ArcIterator<Fst<Arc> > aiter(*ifst, s); 292 !aiter.Done(); aiter.Next()) { 293 if (have_stats_) 294 ++stats_[i].narcs; 295 const Arc& arc = aiter.Value(); 296 297 typename NonTerminalHash::const_iterator it = 298 nonterminal_hash_.find(arc.olabel); 299 if (it != nonterminal_hash_.end()) { 300 Label j = it->second; 301 depfst_.AddArc(i, Arc(arc.olabel, arc.olabel, Weight::One(), j)); 302 if (have_stats_) { 303 ++stats_[i].nnonterms; 304 ++stats_[j].nref; 305 ++stats_[j].inref[i]; 306 ++stats_[i].outref[j]; 307 } 308 } 309 } 310 } 311 } 312 313 // Gets accessibility info 314 SccVisitor<Arc> scc_visitor(0, &depaccess_, 0, &depprops_); 315 DfsVisit(depfst_, &scc_visitor); 316 } 317 318 template <class Arc> 319 void ReplaceUtil<Arc>::UpdateStats(Label j) { 320 if (!have_stats_) { 321 FSTERROR() << "ReplaceUtil::UpdateStats: stats not available"; 322 return; 323 } 324 325 if (j == root_fst_) // can't replace root 326 return; 327 328 typedef typename map<Label, size_t>::iterator Iter; 329 for (Iter in = stats_[j].inref.begin(); 330 in != stats_[j].inref.end(); 331 ++in) { 332 Label i = in->first; 333 size_t ni = in->second; 334 stats_[i].nstates += stats_[j].nstates * ni; 335 stats_[i].narcs += (stats_[j].narcs + 1) * ni; // narcs - 1 + 2 (eps) 336 stats_[i].nnonterms += (stats_[j].nnonterms - 1) * ni; 337 stats_[i].outref.erase(stats_[i].outref.find(j)); 338 for (Iter out = stats_[j].outref.begin(); 339 out != stats_[j].outref.end(); 340 ++out) { 341 Label k = out->first; 342 size_t nk = out->second; 343 stats_[i].outref[k] += ni * nk; 344 } 345 } 346 347 for (Iter out = stats_[j].outref.begin(); 348 out != stats_[j].outref.end(); 349 ++out) { 350 Label k = out->first; 351 size_t nk = out->second; 352 stats_[k].nref -= nk; 353 stats_[k].inref.erase(stats_[k].inref.find(j)); 354 for (Iter in = stats_[j].inref.begin(); 355 in != stats_[j].inref.end(); 356 ++in) { 357 Label i = in->first; 358 size_t ni = in->second; 359 stats_[k].inref[i] += ni * nk; 360 stats_[k].nref += ni * nk; 361 } 362 } 363 } 364 365 template <class Arc> 366 void ReplaceUtil<Arc>::CheckMutableFsts() { 367 if (mutable_fst_array_.size() == 0) { 368 for (Label i = 0; i < fst_array_.size(); ++i) { 369 if (!fst_array_[i]) { 370 mutable_fst_array_.push_back(0); 371 } else { 372 mutable_fst_array_.push_back(new VectorFst<Arc>(*fst_array_[i])); 373 delete fst_array_[i]; 374 fst_array_[i] = mutable_fst_array_[i]; 375 } 376 } 377 } 378 } 379 380 template <class Arc> 381 void ReplaceUtil<Arc>::Connect() { 382 CheckMutableFsts(); 383 uint64 props = kAccessible | kCoAccessible; 384 for (Label i = 0; i < mutable_fst_array_.size(); ++i) { 385 if (!mutable_fst_array_[i]) 386 continue; 387 if (mutable_fst_array_[i]->Properties(props, false) != props) 388 fst::Connect(mutable_fst_array_[i]); 389 } 390 GetDependencies(false); 391 for (Label i = 0; i < mutable_fst_array_.size(); ++i) { 392 MutableFst<Arc> *fst = mutable_fst_array_[i]; 393 if (fst && !depaccess_[i]) { 394 delete fst; 395 fst_array_[i] = 0; 396 mutable_fst_array_[i] = 0; 397 } 398 } 399 ClearDependencies(); 400 } 401 402 template <class Arc> 403 bool ReplaceUtil<Arc>::GetTopOrder(const Fst<Arc> &fst, 404 vector<Label> *toporder) const { 405 // Finds topological order of dependencies. 406 vector<StateId> order; 407 bool acyclic = false; 408 409 TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic); 410 DfsVisit(fst, &top_order_visitor); 411 if (!acyclic) { 412 LOG(WARNING) << "ReplaceUtil::GetTopOrder: Cyclical label dependencies"; 413 return false; 414 } 415 416 toporder->resize(order.size()); 417 for (Label i = 0; i < order.size(); ++i) 418 (*toporder)[order[i]] = i; 419 420 return true; 421 } 422 423 template <class Arc> 424 void ReplaceUtil<Arc>::ReplaceLabels(const vector<Label> &labels) { 425 CheckMutableFsts(); 426 unordered_set<Label> label_set; 427 for (Label i = 0; i < labels.size(); ++i) 428 if (labels[i] != root_label_) // can't replace root 429 label_set.insert(labels[i]); 430 431 // Finds Fst dependencies restricted to the labels requested. 432 GetDependencies(false); 433 VectorFst<Arc> pfst(depfst_); 434 for (StateId i = 0; i < pfst.NumStates(); ++i) { 435 vector<Arc> arcs; 436 for (ArcIterator< VectorFst<Arc> > aiter(pfst, i); 437 !aiter.Done(); aiter.Next()) { 438 const Arc &arc = aiter.Value(); 439 Label label = nonterminal_array_[arc.nextstate]; 440 if (label_set.count(label) > 0) 441 arcs.push_back(arc); 442 } 443 pfst.DeleteArcs(i); 444 for (size_t j = 0; j < arcs.size(); ++j) 445 pfst.AddArc(i, arcs[j]); 446 } 447 448 vector<Label> toporder; 449 if (!GetTopOrder(pfst, &toporder)) { 450 ClearDependencies(); 451 return; 452 } 453 454 // Visits Fsts in reverse topological order of dependencies and 455 // performs replacements. 456 for (Label o = toporder.size() - 1; o >= 0; --o) { 457 vector<FstPair> fst_pairs; 458 StateId s = toporder[o]; 459 for (ArcIterator< VectorFst<Arc> > aiter(pfst, s); 460 !aiter.Done(); aiter.Next()) { 461 const Arc &arc = aiter.Value(); 462 Label label = nonterminal_array_[arc.nextstate]; 463 const Fst<Arc> *fst = fst_array_[arc.nextstate]; 464 fst_pairs.push_back(make_pair(label, fst)); 465 } 466 if (fst_pairs.empty()) 467 continue; 468 Label label = nonterminal_array_[s]; 469 const Fst<Arc> *fst = fst_array_[s]; 470 fst_pairs.push_back(make_pair(label, fst)); 471 472 Replace(fst_pairs, mutable_fst_array_[s], label, epsilon_on_replace_); 473 } 474 ClearDependencies(); 475 } 476 477 template <class Arc> 478 void ReplaceUtil<Arc>::ReplaceBySize(size_t nstates, size_t narcs, 479 size_t nnonterms) { 480 vector<Label> labels; 481 GetDependencies(true); 482 483 vector<Label> toporder; 484 if (!GetTopOrder(depfst_, &toporder)) { 485 ClearDependencies(); 486 return; 487 } 488 489 for (Label o = toporder.size() - 1; o >= 0; --o) { 490 Label j = toporder[o]; 491 if (stats_[j].nstates <= nstates && 492 stats_[j].narcs <= narcs && 493 stats_[j].nnonterms <= nnonterms) { 494 labels.push_back(nonterminal_array_[j]); 495 UpdateStats(j); 496 } 497 } 498 ReplaceLabels(labels); 499 } 500 501 template <class Arc> 502 void ReplaceUtil<Arc>::ReplaceByInstances(size_t ninstances) { 503 vector<Label> labels; 504 GetDependencies(true); 505 506 vector<Label> toporder; 507 if (!GetTopOrder(depfst_, &toporder)) { 508 ClearDependencies(); 509 return; 510 } 511 for (Label o = 0; o < toporder.size(); ++o) { 512 Label j = toporder[o]; 513 if (stats_[j].nref <= ninstances) { 514 labels.push_back(nonterminal_array_[j]); 515 UpdateStats(j); 516 } 517 } 518 ReplaceLabels(labels); 519 } 520 521 template <class Arc> 522 void ReplaceUtil<Arc>::GetFstPairs(vector<FstPair> *fst_pairs) { 523 CheckMutableFsts(); 524 fst_pairs->clear(); 525 for (Label i = 0; i < fst_array_.size(); ++i) { 526 Label label = nonterminal_array_[i]; 527 const Fst<Arc> *fst = fst_array_[i]; 528 if (!fst) 529 continue; 530 fst_pairs->push_back(make_pair(label, fst)); 531 } 532 } 533 534 template <class Arc> 535 void ReplaceUtil<Arc>::GetMutableFstPairs( 536 vector<MutableFstPair> *mutable_fst_pairs) { 537 CheckMutableFsts(); 538 mutable_fst_pairs->clear(); 539 for (Label i = 0; i < mutable_fst_array_.size(); ++i) { 540 Label label = nonterminal_array_[i]; 541 MutableFst<Arc> *fst = mutable_fst_array_[i]; 542 if (!fst) 543 continue; 544 mutable_fst_pairs->push_back(make_pair(label, fst->Copy())); 545 } 546 } 547 548 } // namespace fst 549 550 #endif // FST_LIB_REPLACE_UTIL_H__ 551