1 // replace.h 2 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Copyright 2005-2010 Google, Inc. 16 // Author: johans (at) google.com (Johan Schalkwyk) 17 // 18 // \file 19 // Functions and classes for the recursive replacement of Fsts. 20 // 21 22 #ifndef FST_LIB_REPLACE_H__ 23 #define FST_LIB_REPLACE_H__ 24 25 #include <tr1/unordered_map> 26 using std::tr1::unordered_map; 27 using std::tr1::unordered_multimap; 28 #include <set> 29 #include <string> 30 #include <utility> 31 using std::pair; using std::make_pair; 32 #include <vector> 33 using std::vector; 34 35 #include <fst/cache.h> 36 #include <fst/expanded-fst.h> 37 #include <fst/fst.h> 38 #include <fst/matcher.h> 39 #include <fst/replace-util.h> 40 #include <fst/state-table.h> 41 #include <fst/test-properties.h> 42 43 namespace fst { 44 45 // 46 // REPLACE STATE TUPLES AND TABLES 47 // 48 // The replace state table has the form 49 // 50 // template <class A, class P> 51 // class ReplaceStateTable { 52 // public: 53 // typedef A Arc; 54 // typedef P PrefixId; 55 // typedef typename A::StateId StateId; 56 // typedef ReplaceStateTuple<StateId, PrefixId> StateTuple; 57 // typedef typename A::Label Label; 58 // 59 // // Required constuctor 60 // ReplaceStateTable(const vector<pair<Label, const Fst<A>*> > &fst_tuples, 61 // Label root); 62 // 63 // // Required copy constructor that does not copy state 64 // ReplaceStateTable(const ReplaceStateTable<A,P> &table); 65 // 66 // // Lookup 
state ID by tuple. If it doesn't exist, then add it. 67 // StateId FindState(const StateTuple &tuple); 68 // 69 // // Lookup state tuple by ID. 70 // const StateTuple &Tuple(StateId id) const; 71 // }; 72 73 74 // \struct ReplaceStateTuple 75 // \brief Tuple of information that uniquely defines a state in replace 76 template <class S, class P> 77 struct ReplaceStateTuple { 78 typedef S StateId; 79 typedef P PrefixId; 80 81 ReplaceStateTuple() 82 : prefix_id(-1), fst_id(kNoStateId), fst_state(kNoStateId) {} 83 84 ReplaceStateTuple(PrefixId p, StateId f, StateId s) 85 : prefix_id(p), fst_id(f), fst_state(s) {} 86 87 PrefixId prefix_id; // index in prefix table 88 StateId fst_id; // current fst being walked 89 StateId fst_state; // current state in fst being walked, not to be 90 // confused with the state_id of the combined fst 91 }; 92 93 94 // Equality of replace state tuples. 95 template <class S, class P> 96 inline bool operator==(const ReplaceStateTuple<S, P>& x, 97 const ReplaceStateTuple<S, P>& y) { 98 return x.prefix_id == y.prefix_id && 99 x.fst_id == y.fst_id && 100 x.fst_state == y.fst_state; 101 } 102 103 104 // \class ReplaceRootSelector 105 // Functor returning true for tuples corresponding to states in the root FST 106 template <class S, class P> 107 class ReplaceRootSelector { 108 public: 109 bool operator()(const ReplaceStateTuple<S, P> &tuple) const { 110 return tuple.prefix_id == 0; 111 } 112 }; 113 114 115 // \class ReplaceFingerprint 116 // Fingerprint for general replace state tuples. 
117 template <class S, class P> 118 class ReplaceFingerprint { 119 public: 120 ReplaceFingerprint(const vector<uint64> *size_array) 121 : cumulative_size_array_(size_array) {} 122 123 uint64 operator()(const ReplaceStateTuple<S, P> &tuple) const { 124 return tuple.prefix_id * (cumulative_size_array_->back()) + 125 cumulative_size_array_->at(tuple.fst_id - 1) + 126 tuple.fst_state; 127 } 128 129 private: 130 const vector<uint64> *cumulative_size_array_; 131 }; 132 133 134 // \class ReplaceFstStateFingerprint 135 // Useful when the fst_state uniquely define the tuple. 136 template <class S, class P> 137 class ReplaceFstStateFingerprint { 138 public: 139 uint64 operator()(const ReplaceStateTuple<S, P>& tuple) const { 140 return tuple.fst_state; 141 } 142 }; 143 144 145 // \class ReplaceHash 146 // A generic hash function for replace state tuples. 147 template <typename S, typename P> 148 class ReplaceHash { 149 public: 150 size_t operator()(const ReplaceStateTuple<S, P>& t) const { 151 return t.prefix_id + t.fst_id * kPrime0 + t.fst_state * kPrime1; 152 } 153 private: 154 static const size_t kPrime0; 155 static const size_t kPrime1; 156 }; 157 158 template <typename S, typename P> 159 const size_t ReplaceHash<S, P>::kPrime0 = 7853; 160 161 template <typename S, typename P> 162 const size_t ReplaceHash<S, P>::kPrime1 = 7867; 163 164 template <class A, class T> class ReplaceFstMatcher; 165 166 167 // \class VectorHashReplaceStateTable 168 // A two-level state table for replace. 169 // Warning: calls CountStates to compute the number of states of each 170 // component Fst. 
171 template <class A, class P = ssize_t> 172 class VectorHashReplaceStateTable { 173 public: 174 typedef A Arc; 175 typedef typename A::StateId StateId; 176 typedef typename A::Label Label; 177 typedef P PrefixId; 178 typedef ReplaceStateTuple<StateId, P> StateTuple; 179 typedef VectorHashStateTable<ReplaceStateTuple<StateId, P>, 180 ReplaceRootSelector<StateId, P>, 181 ReplaceFstStateFingerprint<StateId, P>, 182 ReplaceFingerprint<StateId, P> > StateTable; 183 184 VectorHashReplaceStateTable( 185 const vector<pair<Label, const Fst<A>*> > &fst_tuples, 186 Label root) : root_size_(0) { 187 cumulative_size_array_.push_back(0); 188 for (size_t i = 0; i < fst_tuples.size(); ++i) { 189 if (fst_tuples[i].first == root) { 190 root_size_ = CountStates(*(fst_tuples[i].second)); 191 cumulative_size_array_.push_back(cumulative_size_array_.back()); 192 } else { 193 cumulative_size_array_.push_back(cumulative_size_array_.back() + 194 CountStates(*(fst_tuples[i].second))); 195 } 196 } 197 state_table_ = new StateTable( 198 new ReplaceRootSelector<StateId, P>, 199 new ReplaceFstStateFingerprint<StateId, P>, 200 new ReplaceFingerprint<StateId, P>(&cumulative_size_array_), 201 root_size_, 202 root_size_ + cumulative_size_array_.back()); 203 } 204 205 VectorHashReplaceStateTable(const VectorHashReplaceStateTable<A, P> &table) 206 : root_size_(table.root_size_), 207 cumulative_size_array_(table.cumulative_size_array_) { 208 state_table_ = new StateTable( 209 new ReplaceRootSelector<StateId, P>, 210 new ReplaceFstStateFingerprint<StateId, P>, 211 new ReplaceFingerprint<StateId, P>(&cumulative_size_array_), 212 root_size_, 213 root_size_ + cumulative_size_array_.back()); 214 } 215 216 ~VectorHashReplaceStateTable() { 217 delete state_table_; 218 } 219 220 StateId FindState(const StateTuple &tuple) { 221 return state_table_->FindState(tuple); 222 } 223 224 const StateTuple &Tuple(StateId id) const { 225 return state_table_->Tuple(id); 226 } 227 228 private: 229 StateId root_size_; 230 
vector<uint64> cumulative_size_array_; 231 StateTable *state_table_; 232 }; 233 234 235 // \class DefaultReplaceStateTable 236 // Default replace state table 237 template <class A, class P = ssize_t> 238 class DefaultReplaceStateTable : public CompactHashStateTable< 239 ReplaceStateTuple<typename A::StateId, P>, 240 ReplaceHash<typename A::StateId, P> > { 241 public: 242 typedef A Arc; 243 typedef typename A::StateId StateId; 244 typedef typename A::Label Label; 245 typedef P PrefixId; 246 typedef ReplaceStateTuple<StateId, P> StateTuple; 247 typedef CompactHashStateTable<StateTuple, 248 ReplaceHash<StateId, PrefixId> > StateTable; 249 250 using StateTable::FindState; 251 using StateTable::Tuple; 252 253 DefaultReplaceStateTable( 254 const vector<pair<Label, const Fst<A>*> > &fst_tuples, 255 Label root) {} 256 257 DefaultReplaceStateTable(const DefaultReplaceStateTable<A, P> &table) 258 : StateTable() {} 259 }; 260 261 // 262 // REPLACE FST CLASS 263 // 264 265 // By default ReplaceFst will copy the input label of the 'replace arc'. 266 // For acceptors we do not want this behaviour. Instead we need to 267 // create an epsilon arc when recursing into the appropriate Fst. 268 // The 'epsilon_on_replace' option can be used to toggle this behaviour. 
269 template <class A, class T = DefaultReplaceStateTable<A> > 270 struct ReplaceFstOptions : CacheOptions { 271 int64 root; // root rule for expansion 272 bool epsilon_on_replace; 273 bool take_ownership; // take ownership of input Fst(s) 274 T* state_table; 275 276 ReplaceFstOptions(const CacheOptions &opts, int64 r) 277 : CacheOptions(opts), 278 root(r), 279 epsilon_on_replace(false), 280 take_ownership(false), 281 state_table(0) {} 282 explicit ReplaceFstOptions(int64 r) 283 : root(r), 284 epsilon_on_replace(false), 285 take_ownership(false), 286 state_table(0) {} 287 ReplaceFstOptions(int64 r, bool epsilon_replace_arc) 288 : root(r), 289 epsilon_on_replace(epsilon_replace_arc), 290 take_ownership(false), 291 state_table(0) {} 292 ReplaceFstOptions() 293 : root(kNoLabel), 294 epsilon_on_replace(false), 295 take_ownership(false), 296 state_table(0) {} 297 }; 298 299 300 // \class ReplaceFstImpl 301 // \brief Implementation class for replace class Fst 302 // 303 // The replace implementation class supports a dynamic 304 // expansion of a recursive transition network represented as Fst 305 // with dynamic replacable arcs. 
306 // 307 template <class A, class T> 308 class ReplaceFstImpl : public CacheImpl<A> { 309 friend class ReplaceFstMatcher<A, T>; 310 311 public: 312 using FstImpl<A>::SetType; 313 using FstImpl<A>::SetProperties; 314 using FstImpl<A>::WriteHeader; 315 using FstImpl<A>::SetInputSymbols; 316 using FstImpl<A>::SetOutputSymbols; 317 using FstImpl<A>::InputSymbols; 318 using FstImpl<A>::OutputSymbols; 319 320 using CacheImpl<A>::PushArc; 321 using CacheImpl<A>::HasArcs; 322 using CacheImpl<A>::HasFinal; 323 using CacheImpl<A>::HasStart; 324 using CacheImpl<A>::SetArcs; 325 using CacheImpl<A>::SetFinal; 326 using CacheImpl<A>::SetStart; 327 328 typedef typename A::Label Label; 329 typedef typename A::Weight Weight; 330 typedef typename A::StateId StateId; 331 typedef CacheState<A> State; 332 typedef A Arc; 333 typedef unordered_map<Label, Label> NonTerminalHash; 334 335 typedef T StateTable; 336 typedef typename T::PrefixId PrefixId; 337 typedef ReplaceStateTuple<StateId, PrefixId> StateTuple; 338 339 // constructor for replace class implementation. 340 // \param fst_tuples array of label/fst tuples, one for each non-terminal 341 ReplaceFstImpl(const vector< pair<Label, const Fst<A>* > >& fst_tuples, 342 const ReplaceFstOptions<A, T> &opts) 343 : CacheImpl<A>(opts), 344 epsilon_on_replace_(opts.epsilon_on_replace), 345 state_table_(opts.state_table ? opts.state_table : 346 new StateTable(fst_tuples, opts.root)) { 347 348 SetType("replace"); 349 350 if (fst_tuples.size() > 0) { 351 SetInputSymbols(fst_tuples[0].second->InputSymbols()); 352 SetOutputSymbols(fst_tuples[0].second->OutputSymbols()); 353 } 354 355 bool all_negative = true; // all nonterminals are negative? 356 bool dense_range = true; // all nonterminals are positive 357 // and form a dense range containing 1? 
358 for (size_t i = 0; i < fst_tuples.size(); ++i) { 359 Label nonterminal = fst_tuples[i].first; 360 if (nonterminal >= 0) 361 all_negative = false; 362 if (nonterminal > fst_tuples.size() || nonterminal <= 0) 363 dense_range = false; 364 } 365 366 vector<uint64> inprops; 367 bool all_ilabel_sorted = true; 368 bool all_olabel_sorted = true; 369 bool all_non_empty = true; 370 fst_array_.push_back(0); 371 for (size_t i = 0; i < fst_tuples.size(); ++i) { 372 Label label = fst_tuples[i].first; 373 const Fst<A> *fst = fst_tuples[i].second; 374 nonterminal_hash_[label] = fst_array_.size(); 375 nonterminal_set_.insert(label); 376 fst_array_.push_back(opts.take_ownership ? fst : fst->Copy()); 377 if (fst->Start() == kNoStateId) 378 all_non_empty = false; 379 if(!fst->Properties(kILabelSorted, false)) 380 all_ilabel_sorted = false; 381 if(!fst->Properties(kOLabelSorted, false)) 382 all_olabel_sorted = false; 383 inprops.push_back(fst->Properties(kCopyProperties, false)); 384 if (i) { 385 if (!CompatSymbols(InputSymbols(), fst->InputSymbols())) { 386 FSTERROR() << "ReplaceFstImpl: input symbols of Fst " << i 387 << " does not match input symbols of base Fst (0'th fst)"; 388 SetProperties(kError, kError); 389 } 390 if (!CompatSymbols(OutputSymbols(), fst->OutputSymbols())) { 391 FSTERROR() << "ReplaceFstImpl: output symbols of Fst " << i 392 << " does not match output symbols of base Fst " 393 << "(0'th fst)"; 394 SetProperties(kError, kError); 395 } 396 } 397 } 398 Label nonterminal = nonterminal_hash_[opts.root]; 399 if ((nonterminal == 0) && (fst_array_.size() > 1)) { 400 FSTERROR() << "ReplaceFstImpl: no Fst corresponding to root label '" 401 << opts.root << "' in the input tuple vector"; 402 SetProperties(kError, kError); 403 } 404 root_ = (nonterminal > 0) ? nonterminal : 1; 405 406 SetProperties(ReplaceProperties(inprops, root_ - 1, epsilon_on_replace_, 407 all_non_empty)); 408 // We assume that all terminals are positive. 
The resulting 409 // ReplaceFst is known to be kILabelSorted when all sub-FSTs are 410 // kILabelSorted and one of the 3 following conditions is satisfied: 411 // 1. 'epsilon_on_replace' is false, or 412 // 2. all non-terminals are negative, or 413 // 3. all non-terninals are positive and form a dense range containing 1. 414 if (all_ilabel_sorted && 415 (!epsilon_on_replace_ || all_negative || dense_range)) 416 SetProperties(kILabelSorted, kILabelSorted); 417 // Similarly, the resulting ReplaceFst is known to be 418 // kOLabelSorted when all sub-FSTs are kOLabelSorted and one of 419 // the 2 following conditions is satisfied: 420 // 1. all non-terminals are negative, or 421 // 2. all non-terninals are positive and form a dense range containing 1. 422 if (all_olabel_sorted && (all_negative || dense_range)) 423 SetProperties(kOLabelSorted, kOLabelSorted); 424 425 // Enable optional caching as long as sorted and all non empty. 426 if (Properties(kILabelSorted | kOLabelSorted) && all_non_empty) 427 always_cache_ = false; 428 else 429 always_cache_ = true; 430 VLOG(2) << "ReplaceFstImpl::ReplaceFstImpl: always_cache = " 431 << (always_cache_ ? "true" : "false"); 432 } 433 434 ReplaceFstImpl(const ReplaceFstImpl& impl) 435 : CacheImpl<A>(impl), 436 epsilon_on_replace_(impl.epsilon_on_replace_), 437 always_cache_(impl.always_cache_), 438 state_table_(new StateTable(*(impl.state_table_))), 439 nonterminal_set_(impl.nonterminal_set_), 440 nonterminal_hash_(impl.nonterminal_hash_), 441 root_(impl.root_) { 442 SetType("replace"); 443 SetProperties(impl.Properties(), kCopyProperties); 444 SetInputSymbols(impl.InputSymbols()); 445 SetOutputSymbols(impl.OutputSymbols()); 446 fst_array_.reserve(impl.fst_array_.size()); 447 fst_array_.push_back(0); 448 for (size_t i = 1; i < impl.fst_array_.size(); ++i) { 449 fst_array_.push_back(impl.fst_array_[i]->Copy(true)); 450 } 451 } 452 453 ~ReplaceFstImpl() { 454 VLOG(2) << "~ReplaceFstImpl: gc = " 455 << (CacheImpl<A>::GetCacheGc() ? 
"true" : "false") 456 << ", gc_size = " << CacheImpl<A>::GetCacheSize() 457 << ", gc_limit = " << CacheImpl<A>::GetCacheLimit(); 458 459 delete state_table_; 460 for (size_t i = 1; i < fst_array_.size(); ++i) { 461 delete fst_array_[i]; 462 } 463 } 464 465 // Computes the dependency graph of the replace class and returns 466 // true if the dependencies are cyclic. Cyclic dependencies will result 467 // in an un-expandable replace fst. 468 bool CyclicDependencies() const { 469 ReplaceUtil<A> replace_util(fst_array_, nonterminal_hash_, root_); 470 return replace_util.CyclicDependencies(); 471 } 472 473 // Return or compute start state of replace fst 474 StateId Start() { 475 if (!HasStart()) { 476 if (fst_array_.size() == 1) { // no fsts defined for replace 477 SetStart(kNoStateId); 478 return kNoStateId; 479 } else { 480 const Fst<A>* fst = fst_array_[root_]; 481 StateId fst_start = fst->Start(); 482 if (fst_start == kNoStateId) // root Fst is empty 483 return kNoStateId; 484 485 PrefixId prefix = GetPrefixId(StackPrefix()); 486 StateId start = state_table_->FindState( 487 StateTuple(prefix, root_, fst_start)); 488 SetStart(start); 489 return start; 490 } 491 } else { 492 return CacheImpl<A>::Start(); 493 } 494 } 495 496 // return final weight of state (kInfWeight means state is not final) 497 Weight Final(StateId s) { 498 if (!HasFinal(s)) { 499 const StateTuple& tuple = state_table_->Tuple(s); 500 const StackPrefix& stack = stackprefix_array_[tuple.prefix_id]; 501 const Fst<A>* fst = fst_array_[tuple.fst_id]; 502 StateId fst_state = tuple.fst_state; 503 504 if (fst->Final(fst_state) != Weight::Zero() && stack.Depth() == 0) 505 SetFinal(s, fst->Final(fst_state)); 506 else 507 SetFinal(s, Weight::Zero()); 508 } 509 return CacheImpl<A>::Final(s); 510 } 511 512 size_t NumArcs(StateId s) { 513 if (HasArcs(s)) { // If state cached, use the cached value. 514 return CacheImpl<A>::NumArcs(s); 515 } else if (always_cache_) { // If always caching, expand and cache state. 
516 Expand(s); 517 return CacheImpl<A>::NumArcs(s); 518 } else { // Otherwise compute the number of arcs without expanding. 519 StateTuple tuple = state_table_->Tuple(s); 520 if (tuple.fst_state == kNoStateId) 521 return 0; 522 523 const Fst<A>* fst = fst_array_[tuple.fst_id]; 524 size_t num_arcs = fst->NumArcs(tuple.fst_state); 525 if (ComputeFinalArc(tuple, 0)) 526 num_arcs++; 527 528 return num_arcs; 529 } 530 } 531 532 // Returns whether a given label is a non terminal 533 bool IsNonTerminal(Label l) const { 534 // TODO(allauzen): be smarter and take advantage of 535 // all_dense or all_negative. 536 // Use also in ComputeArc, this would require changes to replace 537 // so that recursing into an empty fst lead to a non co-accessible 538 // state instead of deleting the arc as done currently. 539 // Current use correct, since i/olabel sorted iff all_non_empty. 540 typename NonTerminalHash::const_iterator it = 541 nonterminal_hash_.find(l); 542 return it != nonterminal_hash_.end(); 543 } 544 545 size_t NumInputEpsilons(StateId s) { 546 if (HasArcs(s)) { 547 // If state cached, use the cached value. 548 return CacheImpl<A>::NumInputEpsilons(s); 549 } else if (always_cache_ || !Properties(kILabelSorted)) { 550 // If always caching or if the number of input epsilons is too expensive 551 // to compute without caching (i.e. not ilabel sorted), 552 // then expand and cache state. 553 Expand(s); 554 return CacheImpl<A>::NumInputEpsilons(s); 555 } else { 556 // Otherwise, compute the number of input epsilons without caching. 557 StateTuple tuple = state_table_->Tuple(s); 558 if (tuple.fst_state == kNoStateId) 559 return 0; 560 const Fst<A>* fst = fst_array_[tuple.fst_id]; 561 size_t num = 0; 562 if (!epsilon_on_replace_) { 563 // If epsilon_on_replace is false, all input epsilon arcs 564 // are also input epsilons arcs in the underlying machine. 
565 fst->NumInputEpsilons(tuple.fst_state); 566 } else { 567 // Otherwise, one need to consider that all non-terminal arcs 568 // in the underlying machine also become input epsilon arc. 569 ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state); 570 for (; !aiter.Done() && 571 ((aiter.Value().ilabel == 0) || 572 IsNonTerminal(aiter.Value().olabel)); 573 aiter.Next()) 574 ++num; 575 } 576 if (ComputeFinalArc(tuple, 0)) 577 num++; 578 return num; 579 } 580 } 581 582 size_t NumOutputEpsilons(StateId s) { 583 if (HasArcs(s)) { 584 // If state cached, use the cached value. 585 return CacheImpl<A>::NumOutputEpsilons(s); 586 } else if(always_cache_ || !Properties(kOLabelSorted)) { 587 // If always caching or if the number of output epsilons is too expensive 588 // to compute without caching (i.e. not olabel sorted), 589 // then expand and cache state. 590 Expand(s); 591 return CacheImpl<A>::NumOutputEpsilons(s); 592 } else { 593 // Otherwise, compute the number of output epsilons without caching. 594 StateTuple tuple = state_table_->Tuple(s); 595 if (tuple.fst_state == kNoStateId) 596 return 0; 597 const Fst<A>* fst = fst_array_[tuple.fst_id]; 598 size_t num = 0; 599 ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state); 600 for (; !aiter.Done() && 601 ((aiter.Value().olabel == 0) || 602 IsNonTerminal(aiter.Value().olabel)); 603 aiter.Next()) 604 ++num; 605 if (ComputeFinalArc(tuple, 0)) 606 num++; 607 return num; 608 } 609 } 610 611 uint64 Properties() const { return Properties(kFstProperties); } 612 613 // Set error if found; return FST impl properties. 614 uint64 Properties(uint64 mask) const { 615 if (mask & kError) { 616 for (size_t i = 1; i < fst_array_.size(); ++i) { 617 if (fst_array_[i]->Properties(kError, false)) 618 SetProperties(kError, kError); 619 } 620 } 621 return FstImpl<Arc>::Properties(mask); 622 } 623 624 // return the base arc iterator, if arcs have not been computed yet, 625 // extend/recurse for new arcs. 
626 void InitArcIterator(StateId s, ArcIteratorData<A> *data) { 627 if (!HasArcs(s)) 628 Expand(s); 629 CacheImpl<A>::InitArcIterator(s, data); 630 // TODO(allauzen): Set behaviour of generic iterator 631 // Warning: ArcIterator<ReplaceFst<A> >::InitCache() 632 // relies on current behaviour. 633 } 634 635 636 // Extend current state (walk arcs one level deep) 637 void Expand(StateId s) { 638 StateTuple tuple = state_table_->Tuple(s); 639 640 // If local fst is empty 641 if (tuple.fst_state == kNoStateId) { 642 SetArcs(s); 643 return; 644 } 645 646 ArcIterator< Fst<A> > aiter( 647 *(fst_array_[tuple.fst_id]), tuple.fst_state); 648 Arc arc; 649 650 // Create a final arc when needed 651 if (ComputeFinalArc(tuple, &arc)) 652 PushArc(s, arc); 653 654 // Expand all arcs leaving the state 655 for (;!aiter.Done(); aiter.Next()) { 656 if (ComputeArc(tuple, aiter.Value(), &arc)) 657 PushArc(s, arc); 658 } 659 660 SetArcs(s); 661 } 662 663 void Expand(StateId s, const StateTuple &tuple, 664 const ArcIteratorData<A> &data) { 665 // If local fst is empty 666 if (tuple.fst_state == kNoStateId) { 667 SetArcs(s); 668 return; 669 } 670 671 ArcIterator< Fst<A> > aiter(data); 672 Arc arc; 673 674 // Create a final arc when needed 675 if (ComputeFinalArc(tuple, &arc)) 676 AddArc(s, arc); 677 678 // Expand all arcs leaving the state 679 for (; !aiter.Done(); aiter.Next()) { 680 if (ComputeArc(tuple, aiter.Value(), &arc)) 681 AddArc(s, arc); 682 } 683 684 SetArcs(s); 685 } 686 687 // If arcp == 0, only returns if a final arc is required, does not 688 // actually compute it. 
689 bool ComputeFinalArc(const StateTuple &tuple, A* arcp, 690 uint32 flags = kArcValueFlags) { 691 const Fst<A>* fst = fst_array_[tuple.fst_id]; 692 StateId fst_state = tuple.fst_state; 693 if (fst_state == kNoStateId) 694 return false; 695 696 // if state is final, pop up stack 697 const StackPrefix& stack = stackprefix_array_[tuple.prefix_id]; 698 if (fst->Final(fst_state) != Weight::Zero() && stack.Depth()) { 699 if (arcp) { 700 arcp->ilabel = 0; 701 arcp->olabel = 0; 702 if (flags & kArcNextStateValue) { 703 PrefixId prefix_id = PopPrefix(stack); 704 const PrefixTuple& top = stack.Top(); 705 arcp->nextstate = state_table_->FindState( 706 StateTuple(prefix_id, top.fst_id, top.nextstate)); 707 } 708 if (flags & kArcWeightValue) 709 arcp->weight = fst->Final(fst_state); 710 } 711 return true; 712 } else { 713 return false; 714 } 715 } 716 717 // Compute the arc in the replace fst corresponding to a given 718 // in the underlying machine. Returns false if the underlying arc 719 // corresponds to no arc in the replace. 720 bool ComputeArc(const StateTuple &tuple, const A &arc, A* arcp, 721 uint32 flags = kArcValueFlags) { 722 if (!epsilon_on_replace_ && 723 (flags == (flags & (kArcILabelValue | kArcWeightValue)))) { 724 *arcp = arc; 725 return true; 726 } 727 728 if (arc.olabel == 0) { // expand local fst 729 StateId nextstate = flags & kArcNextStateValue 730 ? 
state_table_->FindState( 731 StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate)) 732 : kNoStateId; 733 *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate); 734 } else { 735 // check for non terminal 736 typename NonTerminalHash::const_iterator it = 737 nonterminal_hash_.find(arc.olabel); 738 if (it != nonterminal_hash_.end()) { // recurse into non terminal 739 Label nonterminal = it->second; 740 const Fst<A>* nt_fst = fst_array_[nonterminal]; 741 PrefixId nt_prefix = PushPrefix(stackprefix_array_[tuple.prefix_id], 742 tuple.fst_id, arc.nextstate); 743 744 // if start state is valid replace, else arc is implicitly 745 // deleted 746 StateId nt_start = nt_fst->Start(); 747 if (nt_start != kNoStateId) { 748 StateId nt_nextstate = flags & kArcNextStateValue 749 ? state_table_->FindState( 750 StateTuple(nt_prefix, nonterminal, nt_start)) 751 : kNoStateId; 752 Label ilabel = (epsilon_on_replace_) ? 0 : arc.ilabel; 753 *arcp = A(ilabel, 0, arc.weight, nt_nextstate); 754 } else { 755 return false; 756 } 757 } else { 758 StateId nextstate = flags & kArcNextStateValue 759 ? state_table_->FindState( 760 StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate)) 761 : kNoStateId; 762 *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate); 763 } 764 } 765 return true; 766 } 767 768 // Returns the arc iterator flags supported by this Fst. 
769 uint32 ArcIteratorFlags() const { 770 uint32 flags = kArcValueFlags; 771 if (!always_cache_) 772 flags |= kArcNoCache; 773 return flags; 774 } 775 776 T* GetStateTable() const { 777 return state_table_; 778 } 779 780 const Fst<A>* GetFst(Label fst_id) const { 781 return fst_array_[fst_id]; 782 } 783 784 bool EpsilonOnReplace() const { return epsilon_on_replace_; } 785 786 // private helper classes 787 private: 788 static const size_t kPrime0; 789 790 // \class PrefixTuple 791 // \brief Tuple of fst_id and destination state (entry in stack prefix) 792 struct PrefixTuple { 793 PrefixTuple(Label f, StateId s) : fst_id(f), nextstate(s) {} 794 795 Label fst_id; 796 StateId nextstate; 797 }; 798 799 // \class StackPrefix 800 // \brief Container for stack prefix. 801 class StackPrefix { 802 public: 803 StackPrefix() {} 804 805 // copy constructor 806 StackPrefix(const StackPrefix& x) : 807 prefix_(x.prefix_) { 808 } 809 810 void Push(StateId fst_id, StateId nextstate) { 811 prefix_.push_back(PrefixTuple(fst_id, nextstate)); 812 } 813 814 void Pop() { 815 prefix_.pop_back(); 816 } 817 818 const PrefixTuple& Top() const { 819 return prefix_[prefix_.size()-1]; 820 } 821 822 size_t Depth() const { 823 return prefix_.size(); 824 } 825 826 public: 827 vector<PrefixTuple> prefix_; 828 }; 829 830 831 // \class StackPrefixEqual 832 // \brief Compare two stack prefix classes for equality 833 class StackPrefixEqual { 834 public: 835 bool operator()(const StackPrefix& x, const StackPrefix& y) const { 836 if (x.prefix_.size() != y.prefix_.size()) return false; 837 for (size_t i = 0; i < x.prefix_.size(); ++i) { 838 if (x.prefix_[i].fst_id != y.prefix_[i].fst_id || 839 x.prefix_[i].nextstate != y.prefix_[i].nextstate) return false; 840 } 841 return true; 842 } 843 }; 844 845 // 846 // \class StackPrefixKey 847 // \brief Hash function for stack prefix to prefix id 848 class StackPrefixKey { 849 public: 850 size_t operator()(const StackPrefix& x) const { 851 size_t sum = 0; 852 for 
(size_t i = 0; i < x.prefix_.size(); ++i) { 853 sum += x.prefix_[i].fst_id + x.prefix_[i].nextstate*kPrime0; 854 } 855 return sum; 856 } 857 }; 858 859 typedef unordered_map<StackPrefix, PrefixId, StackPrefixKey, StackPrefixEqual> 860 StackPrefixHash; 861 862 // private methods 863 private: 864 // hash stack prefix (return unique index into stackprefix array) 865 PrefixId GetPrefixId(const StackPrefix& prefix) { 866 typename StackPrefixHash::iterator it = prefix_hash_.find(prefix); 867 if (it == prefix_hash_.end()) { 868 PrefixId prefix_id = stackprefix_array_.size(); 869 stackprefix_array_.push_back(prefix); 870 prefix_hash_[prefix] = prefix_id; 871 return prefix_id; 872 } else { 873 return it->second; 874 } 875 } 876 877 // prefix id after a stack pop 878 PrefixId PopPrefix(StackPrefix prefix) { 879 prefix.Pop(); 880 return GetPrefixId(prefix); 881 } 882 883 // prefix id after a stack push 884 PrefixId PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) { 885 prefix.Push(fst_id, nextstate); 886 return GetPrefixId(prefix); 887 } 888 889 890 // private data 891 private: 892 // runtime options 893 bool epsilon_on_replace_; 894 bool always_cache_; // Optionally caching arc iterator disabled when true 895 896 // state table 897 StateTable *state_table_; 898 899 // cross index of unique stack prefix 900 // could potentially have one copy of prefix array 901 StackPrefixHash prefix_hash_; 902 vector<StackPrefix> stackprefix_array_; 903 904 set<Label> nonterminal_set_; 905 NonTerminalHash nonterminal_hash_; 906 vector<const Fst<A>*> fst_array_; 907 Label root_; 908 909 void operator=(const ReplaceFstImpl<A, T> &); // disallow 910 }; 911 912 913 template <class A, class T> 914 const size_t ReplaceFstImpl<A, T>::kPrime0 = 7853; 915 916 // 917 // \class ReplaceFst 918 // \brief Recursivively replaces arcs in the root Fst with other Fsts. 919 // This version is a delayed Fst. 
//
// ReplaceFst supports dynamic replacement of arcs in one Fst with
// another Fst. This replacement is recursive.  ReplaceFst can be used
// to support a variety of delayed constructions such as recursive
// transition networks, union, or closure.  It is constructed with an
// array of Fst(s). One Fst represents the root (or topology)
// machine. The root Fst refers to other Fsts by recursively replacing
// arcs labeled as non-terminals with the matching non-terminal
// Fst. Currently the ReplaceFst uses the output symbols of the arcs
// to determine whether the arc is a non-terminal arc or not. A
// non-terminal can be any label that is not a non-zero terminal label
// in the output alphabet.
//
// Note that the constructor uses a vector of pair<>. These correspond
// to the tuple of non-terminal Label and corresponding Fst. For example
// to implement the closure operation we need 2 Fsts. The first root
// Fst is a single Arc on the start State that self loops, it references
// the particular machine for which we are performing the closure operation.
//
// The ReplaceFst class supports an optionally caching arc iterator:
//    ArcIterator< ReplaceFst<A> >
// The ReplaceFst needs to be built such that it is known to be ilabel
// or olabel sorted (see usage below).
//
// Observe that Matcher<Fst<A> > will use the optionally caching arc
// iterator when available (Fst is ilabel sorted and matching on the
// input, or Fst is olabel sorted and matching on the output).
// In order to obtain the most efficient behaviour, it is recommended
// to set 'epsilon_on_replace' to false (this means constructing acceptors
// as transducers with epsilons on the input side of nonterminal arcs)
// and matching on the input side.
951 // 952 // This class attaches interface to implementation and handles 953 // reference counting, delegating most methods to ImplToFst. 954 template <class A, class T = DefaultReplaceStateTable<A> > 955 class ReplaceFst : public ImplToFst< ReplaceFstImpl<A, T> > { 956 public: 957 friend class ArcIterator< ReplaceFst<A, T> >; 958 friend class StateIterator< ReplaceFst<A, T> >; 959 friend class ReplaceFstMatcher<A, T>; 960 961 typedef A Arc; 962 typedef typename A::Label Label; 963 typedef typename A::Weight Weight; 964 typedef typename A::StateId StateId; 965 typedef CacheState<A> State; 966 typedef ReplaceFstImpl<A, T> Impl; 967 968 using ImplToFst<Impl>::Properties; 969 970 ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array, 971 Label root) 972 : ImplToFst<Impl>(new Impl(fst_array, ReplaceFstOptions<A, T>(root))) {} 973 974 ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array, 975 const ReplaceFstOptions<A, T> &opts) 976 : ImplToFst<Impl>(new Impl(fst_array, opts)) {} 977 978 // See Fst<>::Copy() for doc. 979 ReplaceFst(const ReplaceFst<A, T>& fst, bool safe = false) 980 : ImplToFst<Impl>(fst, safe) {} 981 982 // Get a copy of this ReplaceFst. See Fst<>::Copy() for further doc. 
983 virtual ReplaceFst<A, T> *Copy(bool safe = false) const { 984 return new ReplaceFst<A, T>(*this, safe); 985 } 986 987 virtual inline void InitStateIterator(StateIteratorData<A> *data) const; 988 989 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { 990 GetImpl()->InitArcIterator(s, data); 991 } 992 993 virtual MatcherBase<A> *InitMatcher(MatchType match_type) const { 994 if ((GetImpl()->ArcIteratorFlags() & kArcNoCache) && 995 ((match_type == MATCH_INPUT && Properties(kILabelSorted, false)) || 996 (match_type == MATCH_OUTPUT && Properties(kOLabelSorted, false)))) { 997 return new ReplaceFstMatcher<A, T>(*this, match_type); 998 } 999 else { 1000 VLOG(2) << "Not using replace matcher"; 1001 return 0; 1002 } 1003 } 1004 1005 bool CyclicDependencies() const { 1006 return GetImpl()->CyclicDependencies(); 1007 } 1008 1009 private: 1010 // Makes visible to friends. 1011 Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } 1012 1013 void operator=(const ReplaceFst<A> &fst); // disallow 1014 }; 1015 1016 1017 // Specialization for ReplaceFst. 1018 template<class A, class T> 1019 class StateIterator< ReplaceFst<A, T> > 1020 : public CacheStateIterator< ReplaceFst<A, T> > { 1021 public: 1022 explicit StateIterator(const ReplaceFst<A, T> &fst) 1023 : CacheStateIterator< ReplaceFst<A, T> >(fst, fst.GetImpl()) {} 1024 1025 private: 1026 DISALLOW_COPY_AND_ASSIGN(StateIterator); 1027 }; 1028 1029 1030 // Specialization for ReplaceFst. 1031 // Implements optional caching. It can be used as follows: 1032 // 1033 // ReplaceFst<A> replace; 1034 // ArcIterator< ReplaceFst<A> > aiter(replace, s); 1035 // // Note: ArcIterator< Fst<A> > is always a caching arc iterator. 1036 // aiter.SetFlags(kArcNoCache, kArcNoCache); 1037 // // Use the arc iterator, no arc will be cached, no state will be expanded. 1038 // // The varied 'kArcValueFlags' can be used to decide which part 1039 // // of arc values needs to be computed. 
1040 // aiter.SetFlags(kArcILabelValue, kArcValueFlags); 1041 // // Only want the ilabel for this arc 1042 // aiter.Value(); // Does not compute the destination state. 1043 // aiter.Next(); 1044 // aiter.SetFlags(kArcNextStateValue, kArcNextStateValue); 1045 // // Want both ilabel and nextstate for that arc 1046 // aiter.Value(); // Does compute the destination state and inserts it 1047 // // in the replace state table. 1048 // // No Arc has been cached at that point. 1049 // 1050 template <class A, class T> 1051 class ArcIterator< ReplaceFst<A, T> > { 1052 public: 1053 typedef A Arc; 1054 typedef typename A::StateId StateId; 1055 1056 ArcIterator(const ReplaceFst<A, T> &fst, StateId s) 1057 : fst_(fst), state_(s), pos_(0), offset_(0), flags_(0), arcs_(0), 1058 data_flags_(0), final_flags_(0) { 1059 cache_data_.ref_count = 0; 1060 local_data_.ref_count = 0; 1061 1062 // If FST does not support optional caching, force caching. 1063 if(!(fst_.GetImpl()->ArcIteratorFlags() & kArcNoCache) && 1064 !(fst_.GetImpl()->HasArcs(state_))) 1065 fst_.GetImpl()->Expand(state_); 1066 1067 // If state is already cached, use cached arcs array. 1068 if (fst_.GetImpl()->HasArcs(state_)) { 1069 (fst_.GetImpl())->template CacheImpl<A>::InitArcIterator(state_, 1070 &cache_data_); 1071 num_arcs_ = cache_data_.narcs; 1072 arcs_ = cache_data_.arcs; // 'arcs_' is a ptr to the cached arcs. 1073 data_flags_ = kArcValueFlags; // All the arc member values are valid. 1074 } else { // Otherwise delay decision until Value() is called. 1075 tuple_ = fst_.GetImpl()->GetStateTable()->Tuple(state_); 1076 if (tuple_.fst_state == kNoStateId) { 1077 num_arcs_ = 0; 1078 } else { 1079 // The decision to cache or not to cache has been defered 1080 // until Value() or SetFlags() is called. However, the arc 1081 // iterator is set up now to be ready for non-caching in order 1082 // to keep the Value() method simple and efficient. 
1083 const Fst<A>* fst = fst_.GetImpl()->GetFst(tuple_.fst_id); 1084 fst->InitArcIterator(tuple_.fst_state, &local_data_); 1085 // 'arcs_' is a pointer to the arcs in the underlying machine. 1086 arcs_ = local_data_.arcs; 1087 // Compute the final arc (but not its destination state) 1088 // if a final arc is required. 1089 bool has_final_arc = fst_.GetImpl()->ComputeFinalArc( 1090 tuple_, 1091 &final_arc_, 1092 kArcValueFlags & ~kArcNextStateValue); 1093 // Set the arc value flags that hold for 'final_arc_'. 1094 final_flags_ = kArcValueFlags & ~kArcNextStateValue; 1095 // Compute the number of arcs. 1096 num_arcs_ = local_data_.narcs; 1097 if (has_final_arc) 1098 ++num_arcs_; 1099 // Set the offset between the underlying arc positions and 1100 // the positions in the arc iterator. 1101 offset_ = num_arcs_ - local_data_.narcs; 1102 // Defers the decision to cache or not until Value() or 1103 // SetFlags() is called. 1104 data_flags_ = 0; 1105 } 1106 } 1107 } 1108 1109 ~ArcIterator() { 1110 if (cache_data_.ref_count) 1111 --(*cache_data_.ref_count); 1112 if (local_data_.ref_count) 1113 --(*local_data_.ref_count); 1114 } 1115 1116 void ExpandAndCache() const { 1117 // TODO(allauzen): revisit this 1118 // fst_.GetImpl()->Expand(state_, tuple_, local_data_); 1119 // (fst_.GetImpl())->CacheImpl<A>*>::InitArcIterator(state_, 1120 // &cache_data_); 1121 // 1122 fst_.InitArcIterator(state_, &cache_data_); // Expand and cache state. 1123 arcs_ = cache_data_.arcs; // 'arcs_' is a pointer to the cached arcs. 1124 data_flags_ = kArcValueFlags; // All the arc member values are valid. 1125 offset_ = 0; // No offset 1126 1127 } 1128 1129 void Init() { 1130 if (flags_ & kArcNoCache) { // If caching is disabled 1131 // 'arcs_' is a pointer to the arcs in the underlying machine. 1132 arcs_ = local_data_.arcs; 1133 // Set the arcs value flags that hold for 'arcs_'. 
1134 data_flags_ = kArcWeightValue; 1135 if (!fst_.GetImpl()->EpsilonOnReplace()) 1136 data_flags_ |= kArcILabelValue; 1137 // Set the offset between the underlying arc positions and 1138 // the positions in the arc iterator. 1139 offset_ = num_arcs_ - local_data_.narcs; 1140 } else { // Otherwise, expand and cache 1141 ExpandAndCache(); 1142 } 1143 } 1144 1145 bool Done() const { return pos_ >= num_arcs_; } 1146 1147 const A& Value() const { 1148 // If 'data_flags_' was set to 0, non-caching was not requested 1149 if (!data_flags_) { 1150 // TODO(allauzen): revisit this. 1151 if (flags_ & kArcNoCache) { 1152 // Should never happen. 1153 FSTERROR() << "ReplaceFst: inconsistent arc iterator flags"; 1154 } 1155 ExpandAndCache(); // Expand and cache. 1156 } 1157 1158 if (pos_ - offset_ >= 0) { // The requested arc is not the 'final' arc. 1159 const A& arc = arcs_[pos_ - offset_]; 1160 if ((data_flags_ & flags_) == (flags_ & kArcValueFlags)) { 1161 // If the value flags for 'arc' match the recquired value flags 1162 // then return 'arc'. 1163 return arc; 1164 } else { 1165 // Otherwise, compute the corresponding arc on-the-fly. 1166 fst_.GetImpl()->ComputeArc(tuple_, arc, &arc_, flags_ & kArcValueFlags); 1167 return arc_; 1168 } 1169 } else { // The requested arc is the 'final' arc. 1170 if ((final_flags_ & flags_) != (flags_ & kArcValueFlags)) { 1171 // If the arc value flags that hold for the final arc 1172 // do not match the requested value flags, then 1173 // 'final_arc_' needs to be updated. 
1174 fst_.GetImpl()->ComputeFinalArc(tuple_, &final_arc_, 1175 flags_ & kArcValueFlags); 1176 final_flags_ = flags_ & kArcValueFlags; 1177 } 1178 return final_arc_; 1179 } 1180 } 1181 1182 void Next() { ++pos_; } 1183 1184 size_t Position() const { return pos_; } 1185 1186 void Reset() { pos_ = 0; } 1187 1188 void Seek(size_t pos) { pos_ = pos; } 1189 1190 uint32 Flags() const { return flags_; } 1191 1192 void SetFlags(uint32 f, uint32 mask) { 1193 // Update the flags taking into account what flags are supported 1194 // by the Fst. 1195 flags_ &= ~mask; 1196 flags_ |= (f & fst_.GetImpl()->ArcIteratorFlags()); 1197 // If non-caching is not requested (and caching has not already 1198 // been performed), then flush 'data_flags_' to request caching 1199 // during the next call to Value(). 1200 if (!(flags_ & kArcNoCache) && data_flags_ != kArcValueFlags) { 1201 if (!fst_.GetImpl()->HasArcs(state_)) 1202 data_flags_ = 0; 1203 } 1204 // If 'data_flags_' has been flushed but non-caching is requested 1205 // before calling Value(), then set up the iterator for non-caching. 
1206 if ((f & kArcNoCache) && (!data_flags_)) 1207 Init(); 1208 } 1209 1210 private: 1211 const ReplaceFst<A, T> &fst_; // Reference to the FST 1212 StateId state_; // State in the FST 1213 mutable typename T::StateTuple tuple_; // Tuple corresponding to state_ 1214 1215 ssize_t pos_; // Current position 1216 mutable ssize_t offset_; // Offset between position in iterator and in arcs_ 1217 ssize_t num_arcs_; // Number of arcs at state_ 1218 uint32 flags_; // Behavorial flags for the arc iterator 1219 mutable Arc arc_; // Memory to temporarily store computed arcs 1220 1221 mutable ArcIteratorData<Arc> cache_data_; // Arc iterator data in cache 1222 mutable ArcIteratorData<Arc> local_data_; // Arc iterator data in local fst 1223 1224 mutable const A* arcs_; // Array of arcs 1225 mutable uint32 data_flags_; // Arc value flags valid for data in arcs_ 1226 mutable Arc final_arc_; // Final arc (when required) 1227 mutable uint32 final_flags_; // Arc value flags valid for final_arc_ 1228 1229 DISALLOW_COPY_AND_ASSIGN(ArcIterator); 1230 }; 1231 1232 1233 template <class A, class T> 1234 class ReplaceFstMatcher : public MatcherBase<A> { 1235 public: 1236 typedef A Arc; 1237 typedef typename A::StateId StateId; 1238 typedef typename A::Label Label; 1239 typedef MultiEpsMatcher<Matcher<Fst<A> > > LocalMatcher; 1240 1241 ReplaceFstMatcher(const ReplaceFst<A, T> &fst, fst::MatchType match_type) 1242 : fst_(fst), 1243 impl_(fst_.GetImpl()), 1244 s_(fst::kNoStateId), 1245 match_type_(match_type), 1246 current_loop_(false), 1247 final_arc_(false), 1248 loop_(fst::kNoLabel, 0, A::Weight::One(), fst::kNoStateId) { 1249 if (match_type_ == fst::MATCH_OUTPUT) 1250 swap(loop_.ilabel, loop_.olabel); 1251 InitMatchers(); 1252 } 1253 1254 ReplaceFstMatcher(const ReplaceFstMatcher<A, T> &matcher, bool safe = false) 1255 : fst_(matcher.fst_), 1256 impl_(fst_.GetImpl()), 1257 s_(fst::kNoStateId), 1258 match_type_(matcher.match_type_), 1259 current_loop_(false), 1260 loop_(fst::kNoLabel, 0, 
A::Weight::One(), fst::kNoStateId) { 1261 if (match_type_ == fst::MATCH_OUTPUT) 1262 swap(loop_.ilabel, loop_.olabel); 1263 InitMatchers(); 1264 } 1265 1266 // Create a local matcher for each component Fst of replace. 1267 // LocalMatcher is a multi epsilon wrapper matcher. MultiEpsilonMatcher 1268 // is used to match each non-terminal arc, since these non-terminal 1269 // turn into epsilons on recursion. 1270 void InitMatchers() { 1271 const vector<const Fst<A>*>& fst_array = impl_->fst_array_; 1272 matcher_.resize(fst_array.size(), 0); 1273 for (size_t i = 0; i < fst_array.size(); ++i) { 1274 if (fst_array[i]) { 1275 matcher_[i] = 1276 new LocalMatcher(*fst_array[i], match_type_, kMultiEpsList); 1277 1278 typename set<Label>::iterator it = impl_->nonterminal_set_.begin(); 1279 for (; it != impl_->nonterminal_set_.end(); ++it) { 1280 matcher_[i]->AddMultiEpsLabel(*it); 1281 } 1282 } 1283 } 1284 } 1285 1286 virtual ReplaceFstMatcher<A, T> *Copy(bool safe = false) const { 1287 return new ReplaceFstMatcher<A, T>(*this, safe); 1288 } 1289 1290 virtual ~ReplaceFstMatcher() { 1291 for (size_t i = 0; i < matcher_.size(); ++i) 1292 delete matcher_[i]; 1293 } 1294 1295 virtual MatchType Type(bool test) const { 1296 if (match_type_ == MATCH_NONE) 1297 return match_type_; 1298 1299 uint64 true_prop = match_type_ == MATCH_INPUT ? 1300 kILabelSorted : kOLabelSorted; 1301 uint64 false_prop = match_type_ == MATCH_INPUT ? 1302 kNotILabelSorted : kNotOLabelSorted; 1303 uint64 props = fst_.Properties(true_prop | false_prop, test); 1304 1305 if (props & true_prop) 1306 return match_type_; 1307 else if (props & false_prop) 1308 return MATCH_NONE; 1309 else 1310 return MATCH_UNKNOWN; 1311 } 1312 1313 virtual const Fst<A> &GetFst() const { 1314 return fst_; 1315 } 1316 1317 virtual uint64 Properties(uint64 props) const { 1318 return props; 1319 } 1320 1321 private: 1322 // Set the sate from which our matching happens. 
1323 virtual void SetState_(StateId s) { 1324 if (s_ == s) return; 1325 1326 s_ = s; 1327 tuple_ = impl_->GetStateTable()->Tuple(s_); 1328 if (tuple_.fst_state == kNoStateId) { 1329 done_ = true; 1330 return; 1331 } 1332 // Get current matcher. Used for non epsilon matching 1333 current_matcher_ = matcher_[tuple_.fst_id]; 1334 current_matcher_->SetState(tuple_.fst_state); 1335 loop_.nextstate = s_; 1336 1337 final_arc_ = false; 1338 } 1339 1340 // Search for label, from previous set state. If label == 0, first 1341 // hallucinate and epsilon loop, else use the underlying matcher to 1342 // search for the label or epsilons. 1343 // - Note since the ReplaceFST recursion on non-terminal arcs causes 1344 // epsilon transitions to be created we use the MultiEpsilonMatcher 1345 // to search for possible matches of non terminals. 1346 // - If the component Fst reaches a final state we also need to add 1347 // the exiting final arc. 1348 virtual bool Find_(Label label) { 1349 bool found = false; 1350 label_ = label; 1351 if (label_ == 0 || label_ == kNoLabel) { 1352 // Compute loop directly, saving Replace::ComputeArc 1353 if (label_ == 0) { 1354 current_loop_ = true; 1355 found = true; 1356 } 1357 // Search for matching multi epsilons 1358 final_arc_ = impl_->ComputeFinalArc(tuple_, 0); 1359 found = current_matcher_->Find(kNoLabel) || final_arc_ || found; 1360 } else { 1361 // Search on sub machine directly using sub machine matcher. 
1362 found = current_matcher_->Find(label_); 1363 } 1364 return found; 1365 } 1366 1367 virtual bool Done_() const { 1368 return !current_loop_ && !final_arc_ && current_matcher_->Done(); 1369 } 1370 1371 virtual const Arc& Value_() const { 1372 if (current_loop_) { 1373 return loop_; 1374 } 1375 if (final_arc_) { 1376 impl_->ComputeFinalArc(tuple_, &arc_); 1377 return arc_; 1378 } 1379 const Arc& component_arc = current_matcher_->Value(); 1380 impl_->ComputeArc(tuple_, component_arc, &arc_); 1381 return arc_; 1382 } 1383 1384 virtual void Next_() { 1385 if (current_loop_) { 1386 current_loop_ = false; 1387 return; 1388 } 1389 if (final_arc_) { 1390 final_arc_ = false; 1391 return; 1392 } 1393 current_matcher_->Next(); 1394 } 1395 1396 const ReplaceFst<A, T>& fst_; 1397 ReplaceFstImpl<A, T> *impl_; 1398 LocalMatcher* current_matcher_; 1399 vector<LocalMatcher*> matcher_; 1400 1401 StateId s_; // Current state 1402 Label label_; // Current label 1403 1404 MatchType match_type_; // Supplied by caller 1405 mutable bool done_; 1406 mutable bool current_loop_; // Current arc is the implicit loop 1407 mutable bool final_arc_; // Current arc for exiting recursion 1408 mutable typename T::StateTuple tuple_; // Tuple corresponding to state_ 1409 mutable Arc arc_; 1410 Arc loop_; 1411 }; 1412 1413 template <class A, class T> inline 1414 void ReplaceFst<A, T>::InitStateIterator(StateIteratorData<A> *data) const { 1415 data->base = new StateIterator< ReplaceFst<A, T> >(*this); 1416 } 1417 1418 typedef ReplaceFst<StdArc> StdReplaceFst; 1419 1420 1421 // // Recursivively replaces arcs in the root Fst with other Fsts. 1422 // This version writes the result of replacement to an output MutableFst. 1423 // 1424 // Replace supports replacement of arcs in one Fst with another 1425 // Fst. This replacement is recursive. Replace takes an array of 1426 // Fst(s). One Fst represents the root (or topology) machine. 
The root 1427 // Fst refers to other Fsts by recursively replacing arcs labeled as 1428 // non-terminals with the matching non-terminal Fst. Currently Replace 1429 // uses the output symbols of the arcs to determine whether the arc is 1430 // a non-terminal arc or not. A non-terminal can be any label that is 1431 // not a non-zero terminal label in the output alphabet. Note that 1432 // input argument is a vector of pair<>. These correspond to the tuple 1433 // of non-terminal Label and corresponding Fst. 1434 template<class Arc> 1435 void Replace(const vector<pair<typename Arc::Label, 1436 const Fst<Arc>* > >& ifst_array, 1437 MutableFst<Arc> *ofst, typename Arc::Label root, 1438 bool epsilon_on_replace) { 1439 ReplaceFstOptions<Arc> opts(root, epsilon_on_replace); 1440 opts.gc_limit = 0; // Cache only the last state for fastest copy. 1441 *ofst = ReplaceFst<Arc>(ifst_array, opts); 1442 } 1443 1444 template<class Arc> 1445 void Replace(const vector<pair<typename Arc::Label, 1446 const Fst<Arc>* > >& ifst_array, 1447 MutableFst<Arc> *ofst, typename Arc::Label root) { 1448 Replace(ifst_array, ofst, root, false); 1449 } 1450 1451 } // namespace fst 1452 1453 #endif // FST_LIB_REPLACE_H__ 1454