Home | History | Annotate | Download | only in lib
      1 // replace.h
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 //
     16 // \file
     17 // Functions and classes for the recursive replacement of Fsts.
     18 //
     19 
     20 #ifndef FST_LIB_REPLACE_H__
     21 #define FST_LIB_REPLACE_H__
     22 
     23 #include <ext/hash_map>
     24 using __gnu_cxx::hash_map;
     25 
     26 #include "fst/lib/fst.h"
     27 #include "fst/lib/cache.h"
     28 #include "fst/lib/test-properties.h"
     29 
     30 namespace fst {
     31 
     32 // By default ReplaceFst will copy the input label of the 'replace arc'.
     33 // For acceptors we do not want this behaviour. Instead we need to
     34 // create an epsilon arc when recursing into the appropriate Fst.
     35 // The epsilon_on_replace option can be used to toggle this behaviour.
     36 struct ReplaceFstOptions : CacheOptions {
     37   int64 root;    // root rule for expansion
     38   bool  epsilon_on_replace;
     39 
     40   ReplaceFstOptions(const CacheOptions &opts, int64 r)
     41       : CacheOptions(opts), root(r), epsilon_on_replace(false) {}
     42   explicit ReplaceFstOptions(int64 r)
     43       : root(r), epsilon_on_replace(false) {}
     44   ReplaceFstOptions(int64 r, bool epsilon_replace_arc)
     45       : root(r), epsilon_on_replace(epsilon_replace_arc) {}
     46   ReplaceFstOptions()
     47       : root(kNoLabel), epsilon_on_replace(false) {}
     48 };
     49 
     50 //
     51 // \class ReplaceFstImpl
     52 // \brief Implementation class for replace class Fst
     53 //
     54 // The replace implementation class supports a dynamic
     55 // expansion of a recursive transition network represented as Fst
     56 // with dynamic replacable arcs.
     57 //
     58 template <class A>
     59 class ReplaceFstImpl : public CacheImpl<A> {
     60  public:
     61   using FstImpl<A>::SetType;
     62   using FstImpl<A>::SetProperties;
     63   using FstImpl<A>::Properties;
     64   using FstImpl<A>::SetInputSymbols;
     65   using FstImpl<A>::SetOutputSymbols;
     66   using FstImpl<A>::InputSymbols;
     67   using FstImpl<A>::OutputSymbols;
     68 
     69   using CacheImpl<A>::HasStart;
     70   using CacheImpl<A>::HasArcs;
     71   using CacheImpl<A>::SetStart;
     72 
     73   typedef typename A::Label   Label;
     74   typedef typename A::Weight  Weight;
     75   typedef typename A::StateId StateId;
     76   typedef CacheState<A> State;
     77   typedef A Arc;
     78   typedef hash_map<Label, Label> NonTerminalHash;
     79 
     80 
     81   // \struct StateTuple
     82   // \brief Tuple of information that uniquely defines a state
     83   struct StateTuple {
     84     typedef int PrefixId;
     85 
     86     StateTuple() {}
     87     StateTuple(PrefixId p, StateId f, StateId s) :
     88         prefix_id(p), fst_id(f), fst_state(s) {}
     89 
     90     PrefixId prefix_id;  // index in prefix table
     91     StateId fst_id;      // current fst being walked
     92     StateId fst_state;   // current state in fst being walked, not to be
     93                          // confused with the state_id of the combined fst
     94   };
     95 
     96   // constructor for replace class implementation.
     97   // \param fst_tuples array of label/fst tuples, one for each non-terminal
     98   ReplaceFstImpl(const vector< pair<Label, const Fst<A>* > >& fst_tuples,
     99                  const ReplaceFstOptions &opts)
    100       : CacheImpl<A>(opts), opts_(opts) {
    101     SetType("replace");
    102     if (fst_tuples.size() > 0) {
    103       SetInputSymbols(fst_tuples[0].second->InputSymbols());
    104       SetOutputSymbols(fst_tuples[0].second->OutputSymbols());
    105     }
    106 
    107     fst_array_.push_back(0);
    108     for (size_t i = 0; i < fst_tuples.size(); ++i)
    109       AddFst(fst_tuples[i].first, fst_tuples[i].second);
    110 
    111     SetRoot(opts.root);
    112   }
    113 
    114   explicit ReplaceFstImpl(const ReplaceFstOptions &opts)
    115       : CacheImpl<A>(opts), opts_(opts), root_(kNoLabel) {
    116     fst_array_.push_back(0);
    117   }
    118 
    119   ReplaceFstImpl(const ReplaceFstImpl& impl)
    120       : opts_(impl.opts_), state_tuples_(impl.state_tuples_),
    121         state_hash_(impl.state_hash_),
    122         prefix_hash_(impl.prefix_hash_),
    123         stackprefix_array_(impl.stackprefix_array_),
    124         nonterminal_hash_(impl.nonterminal_hash_),
    125         root_(impl.root_) {
    126     SetType("replace");
    127     SetProperties(impl.Properties(), kCopyProperties);
    128     SetInputSymbols(InputSymbols());
    129     SetOutputSymbols(OutputSymbols());
    130     fst_array_.reserve(impl.fst_array_.size());
    131     fst_array_.push_back(0);
    132     for (size_t i = 1; i < impl.fst_array_.size(); ++i)
    133       fst_array_.push_back(impl.fst_array_[i]->Copy());
    134   }
    135 
    136   ~ReplaceFstImpl() {
    137     for (size_t i = 1; i < fst_array_.size(); ++i) {
    138       delete fst_array_[i];
    139     }
    140   }
    141 
    142   // Add to Fst array
    143   void AddFst(Label label, const Fst<A>* fst) {
    144     nonterminal_hash_[label] = fst_array_.size();
    145     fst_array_.push_back(fst->Copy());
    146     if (fst_array_.size() > 1) {
    147       vector<uint64> inprops(fst_array_.size());
    148 
    149       for (size_t i = 1; i < fst_array_.size(); ++i) {
    150         inprops[i] = fst_array_[i]->Properties(kCopyProperties, false);
    151       }
    152       SetProperties(ReplaceProperties(inprops));
    153 
    154       const SymbolTable* isymbols = fst_array_[1]->InputSymbols();
    155       const SymbolTable* osymbols = fst_array_[1]->OutputSymbols();
    156       for (size_t i = 2; i < fst_array_.size(); ++i) {
    157         if (!CompatSymbols(isymbols, fst_array_[i]->InputSymbols())) {
    158           LOG(FATAL) << "ReplaceFst::AddFst input symbols of Fst " << i-1
    159                      << " does not match input symbols of base Fst (0'th fst)";
    160         }
    161         if (!CompatSymbols(osymbols, fst_array_[i]->OutputSymbols())) {
    162           LOG(FATAL) << "ReplaceFst::AddFst output symbols of Fst " << i-1
    163                      << " does not match output symbols of base Fst "
    164                      << "(0'th fst)";
    165         }
    166       }
    167     }
    168   }
    169 
    170   // Computes the dependency graph of the replace class and returns
    171   // true if the dependencies are cyclic. Cyclic dependencies will result
    172   // in an un-expandable replace fst.
    173   bool CyclicDependencies() const {
    174     StdVectorFst depfst;
    175 
    176     // one state for each fst
    177     for (size_t i = 1; i < fst_array_.size(); ++i)
    178       depfst.AddState();
    179 
    180     // an arc from each state (representing the fst) to the
    181     // state representing the fst being replaced
    182     for (size_t i = 1; i < fst_array_.size(); ++i) {
    183       for (StateIterator<Fst<A> > siter(*(fst_array_[i]));
    184            !siter.Done(); siter.Next()) {
    185         for (ArcIterator<Fst<A> > aiter(*(fst_array_[i]), siter.Value());
    186              !aiter.Done(); aiter.Next()) {
    187           const A& arc = aiter.Value();
    188 
    189           typename NonTerminalHash::const_iterator it =
    190               nonterminal_hash_.find(arc.olabel);
    191           if (it != nonterminal_hash_.end()) {
    192             Label j = it->second - 1;
    193             depfst.AddArc(i - 1, A(arc.olabel, arc.olabel, Weight::One(), j));
    194           }
    195         }
    196       }
    197     }
    198 
    199     depfst.SetStart(root_ - 1);
    200     depfst.SetFinal(root_ - 1, Weight::One());
    201     return depfst.Properties(kCyclic, true);
    202   }
    203 
    204   // set root rule for expansion
    205   void SetRoot(Label root) {
    206     Label nonterminal = nonterminal_hash_[root];
    207     root_ = (nonterminal > 0) ? nonterminal : 1;
    208   }
    209 
    210   // Change Fst array
    211   void SetFst(Label label, const Fst<A>* fst) {
    212     Label nonterminal = nonterminal_hash_[label];
    213     delete fst_array_[nonterminal];
    214     fst_array_[nonterminal] = fst->Copy();
    215   }
    216 
    217   // Return or compute start state of replace fst
    218   StateId Start() {
    219     if (!HasStart()) {
    220       if (fst_array_.size() == 1) {      // no fsts defined for replace
    221         SetStart(kNoStateId);
    222         return kNoStateId;
    223       } else {
    224         const Fst<A>* fst = fst_array_[root_];
    225         StateId fst_start = fst->Start();
    226         if (fst_start == kNoStateId)  // root Fst is empty
    227           return kNoStateId;
    228 
    229         int prefix = PrefixId(StackPrefix());
    230         StateId start = FindState(StateTuple(prefix, root_, fst_start));
    231         SetStart(start);
    232         return start;
    233       }
    234     } else {
    235       return CacheImpl<A>::Start();
    236     }
    237   }
    238 
    239   // return final weight of state (kInfWeight means state is not final)
    240   Weight Final(StateId s) {
    241     if (!HasFinal(s)) {
    242       const StateTuple& tuple  = state_tuples_[s];
    243       const StackPrefix& stack = stackprefix_array_[tuple.prefix_id];
    244       const Fst<A>* fst = fst_array_[tuple.fst_id];
    245       StateId fst_state = tuple.fst_state;
    246 
    247       if (fst->Final(fst_state) != Weight::Zero() && stack.Depth() == 0)
    248         SetFinal(s, fst->Final(fst_state));
    249       else
    250         SetFinal(s, Weight::Zero());
    251     }
    252     return CacheImpl<A>::Final(s);
    253   }
    254 
    255   size_t NumArcs(StateId s) {
    256     if (!HasArcs(s))
    257       Expand(s);
    258     return CacheImpl<A>::NumArcs(s);
    259   }
    260 
    261   size_t NumInputEpsilons(StateId s) {
    262     if (!HasArcs(s))
    263       Expand(s);
    264     return CacheImpl<A>::NumInputEpsilons(s);
    265   }
    266 
    267   size_t NumOutputEpsilons(StateId s) {
    268     if (!HasArcs(s))
    269       Expand(s);
    270     return CacheImpl<A>::NumOutputEpsilons(s);
    271   }
    272 
    273   // return the base arc iterator, if arcs have not been computed yet,
    274   // extend/recurse for new arcs.
    275   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
    276     if (!HasArcs(s))
    277       Expand(s);
    278     CacheImpl<A>::InitArcIterator(s, data);
    279   }
    280 
    281   // Find/create an Fst state given a StateTuple.  Only create a new
    282   // state if StateTuple is not found in the state hash.
    283   StateId FindState(const StateTuple& tuple) {
    284     typename StateTupleHash::iterator it = state_hash_.find(tuple);
    285     if (it == state_hash_.end()) {
    286       StateId new_state_id = state_tuples_.size();
    287       state_tuples_.push_back(tuple);
    288       state_hash_[tuple] = new_state_id;
    289       return new_state_id;
    290     } else {
    291       return it->second;
    292     }
    293   }
    294 
    295   // extend current state (walk arcs one level deep)
    296   void Expand(StateId s) {
    297     StateTuple tuple  = state_tuples_[s];
    298     const Fst<A>* fst = fst_array_[tuple.fst_id];
    299     StateId fst_state = tuple.fst_state;
    300     if (fst_state == kNoStateId) {
    301       SetArcs(s);
    302       return;
    303     }
    304 
    305     // if state is final, pop up stack
    306     const StackPrefix& stack = stackprefix_array_[tuple.prefix_id];
    307     if (fst->Final(fst_state) != Weight::Zero() && stack.Depth()) {
    308       int prefix_id = PopPrefix(stack);
    309       const PrefixTuple& top = stack.Top();
    310 
    311       StateId nextstate =
    312         FindState(StateTuple(prefix_id, top.fst_id, top.nextstate));
    313       AddArc(s, A(0, 0, fst->Final(fst_state), nextstate));
    314     }
    315 
    316     // extend arcs leaving the state
    317     for (ArcIterator< Fst<A> > aiter(*fst, fst_state);
    318          !aiter.Done(); aiter.Next()) {
    319       const Arc& arc = aiter.Value();
    320       if (arc.olabel == 0) {  // expand local fst
    321         StateId nextstate =
    322           FindState(StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate));
    323         AddArc(s, A(arc.ilabel, arc.olabel, arc.weight, nextstate));
    324       } else {
    325         // check for non terminal
    326         typename NonTerminalHash::const_iterator it =
    327             nonterminal_hash_.find(arc.olabel);
    328         if (it != nonterminal_hash_.end()) {  // recurse into non terminal
    329           Label nonterminal = it->second;
    330           const Fst<A>* nt_fst = fst_array_[nonterminal];
    331           int nt_prefix = PushPrefix(stackprefix_array_[tuple.prefix_id],
    332                                      tuple.fst_id, arc.nextstate);
    333 
    334           // if start state is valid replace, else arc is implicitly
    335           // deleted
    336           StateId nt_start = nt_fst->Start();
    337           if (nt_start != kNoStateId) {
    338             StateId nt_nextstate = FindState(
    339                 StateTuple(nt_prefix, nonterminal, nt_start));
    340             Label ilabel = (opts_.epsilon_on_replace) ? 0 : arc.ilabel;
    341             AddArc(s, A(ilabel, 0, arc.weight, nt_nextstate));
    342           }
    343         } else {
    344           StateId nextstate =
    345             FindState(
    346                 StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate));
    347           AddArc(s, A(arc.ilabel, arc.olabel, arc.weight, nextstate));
    348         }
    349       }
    350     }
    351 
    352     SetArcs(s);
    353   }
    354 
    355 
    356   // private helper classes
    357  private:
    358   static const int kPrime0 = 7853;
    359   static const int kPrime1 = 7867;
    360 
    361   // \class StateTupleEqual
    362   // \brief Compare two StateTuples for equality
    363   class StateTupleEqual {
    364    public:
    365     bool operator()(const StateTuple& x, const StateTuple& y) const {
    366       return ((x.prefix_id == y.prefix_id) && (x.fst_id == y.fst_id) &&
    367               (x.fst_state == y.fst_state));
    368     }
    369   };
    370 
    371   // \class StateTupleKey
    372   // \brief Hash function for StateTuple to Fst states
    373   class StateTupleKey {
    374    public:
    375     size_t operator()(const StateTuple& x) const {
    376       return static_cast<size_t>(x.prefix_id +
    377                                  x.fst_id * kPrime0 +
    378                                  x.fst_state * kPrime1);
    379     }
    380   };
    381 
    382   typedef hash_map<StateTuple, StateId, StateTupleKey, StateTupleEqual>
    383   StateTupleHash;
    384 
    385   // \class PrefixTuple
    386   // \brief Tuple of fst_id and destination state (entry in stack prefix)
    387   struct PrefixTuple {
    388     PrefixTuple(Label f, StateId s) : fst_id(f), nextstate(s) {}
    389 
    390     Label   fst_id;
    391     StateId nextstate;
    392   };
    393 
    394   // \class StackPrefix
    395   // \brief Container for stack prefix.
    396   class StackPrefix {
    397    public:
    398     StackPrefix() {}
    399 
    400     // copy constructor
    401     StackPrefix(const StackPrefix& x) :
    402         prefix_(x.prefix_) {
    403     }
    404 
    405     void Push(int fst_id, StateId nextstate) {
    406       prefix_.push_back(PrefixTuple(fst_id, nextstate));
    407     }
    408 
    409     void Pop() {
    410       prefix_.pop_back();
    411     }
    412 
    413     const PrefixTuple& Top() const {
    414       return prefix_[prefix_.size()-1];
    415     }
    416 
    417     size_t Depth() const {
    418       return prefix_.size();
    419     }
    420 
    421    public:
    422     vector<PrefixTuple> prefix_;
    423   };
    424 
    425 
    426   // \class StackPrefixEqual
    427   // \brief Compare two stack prefix classes for equality
    428   class StackPrefixEqual {
    429    public:
    430     bool operator()(const StackPrefix& x, const StackPrefix& y) const {
    431       if (x.prefix_.size() != y.prefix_.size()) return false;
    432       for (size_t i = 0; i < x.prefix_.size(); ++i) {
    433         if (x.prefix_[i].fst_id    != y.prefix_[i].fst_id ||
    434            x.prefix_[i].nextstate != y.prefix_[i].nextstate) return false;
    435       }
    436       return true;
    437     }
    438   };
    439 
    440   //
    441   // \class StackPrefixKey
    442   // \brief Hash function for stack prefix to prefix id
    443   class StackPrefixKey {
    444    public:
    445     size_t operator()(const StackPrefix& x) const {
    446       int sum = 0;
    447       for (size_t i = 0; i < x.prefix_.size(); ++i) {
    448         sum += x.prefix_[i].fst_id + x.prefix_[i].nextstate*kPrime0;
    449       }
    450       return (size_t) sum;
    451     }
    452   };
    453 
    454   typedef hash_map<StackPrefix, int, StackPrefixKey, StackPrefixEqual>
    455   StackPrefixHash;
    456 
    457   // private methods
    458  private:
    459   // hash stack prefix (return unique index into stackprefix array)
    460   int PrefixId(const StackPrefix& prefix) {
    461     typename StackPrefixHash::iterator it = prefix_hash_.find(prefix);
    462     if (it == prefix_hash_.end()) {
    463       int prefix_id = stackprefix_array_.size();
    464       stackprefix_array_.push_back(prefix);
    465       prefix_hash_[prefix] = prefix_id;
    466       return prefix_id;
    467     } else {
    468       return it->second;
    469     }
    470   }
    471 
    472   // prefix id after a stack pop
    473   int PopPrefix(StackPrefix prefix) {
    474     prefix.Pop();
    475     return PrefixId(prefix);
    476   }
    477 
    478   // prefix id after a stack push
    479   int PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) {
    480     prefix.Push(fst_id, nextstate);
    481     return PrefixId(prefix);
    482   }
    483 
    484 
    485   // private data
    486  private:
    487   // runtime options
    488   ReplaceFstOptions opts_;
    489 
    490   // maps from StateId to StateTuple
    491   vector<StateTuple> state_tuples_;
    492 
    493   // hashes from StateTuple to StateId
    494   StateTupleHash state_hash_;
    495 
    496   // cross index of unique stack prefix
    497   // could potentially have one copy of prefix array
    498   StackPrefixHash prefix_hash_;
    499   vector<StackPrefix> stackprefix_array_;
    500 
    501   NonTerminalHash nonterminal_hash_;
    502   vector<const Fst<A>*> fst_array_;
    503 
    504   Label root_;
    505 
    506   void operator=(const ReplaceFstImpl<A> &);  // disallow
    507 };
    508 
    509 
    510 //
    511 // \class ReplaceFst
    512 // \brief Recursivively replaces arcs in the root Fst with other Fsts.
    513 // This version is a delayed Fst.
    514 //
    515 // ReplaceFst supports dynamic replacement of arcs in one Fst with
    516 // another Fst. This replacement is recursive.  ReplaceFst can be used
    517 // to support a variety of delayed constructions such as recursive
    518 // transition networks, union, or closure.  It is constructed with an
    519 // array of Fst(s). One Fst represents the root (or topology)
    520 // machine. The root Fst refers to other Fsts by recursively replacing
    521 // arcs labeled as non-terminals with the matching non-terminal
    522 // Fst. Currently the ReplaceFst uses the output symbols of the arcs
    523 // to determine whether the arc is a non-terminal arc or not. A
    524 // non-terminal can be any label that is not a non-zero terminal label
    525 // in the output alphabet.
    526 //
    527 // Note that the constructor uses a vector of pair<>. These correspond
    528 // to the tuple of non-terminal Label and corresponding Fst. For example
    529 // to implement the closure operation we need 2 Fsts. The first root
    530 // Fst is a single Arc on the start State that self loops, it references
    531 // the particular machine for which we are performing the closure operation.
    532 //
    533 template <class A>
    534 class ReplaceFst : public Fst<A> {
    535  public:
    536   friend class ArcIterator< ReplaceFst<A> >;
    537   friend class CacheStateIterator< ReplaceFst<A> >;
    538   friend class CacheArcIterator< ReplaceFst<A> >;
    539 
    540   typedef A Arc;
    541   typedef typename A::Label   Label;
    542   typedef typename A::Weight  Weight;
    543   typedef typename A::StateId StateId;
    544   typedef CacheState<A> State;
    545 
    546   ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
    547              Label root)
    548       : impl_(new ReplaceFstImpl<A>(fst_array, ReplaceFstOptions(root))) {}
    549 
    550   ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
    551              const ReplaceFstOptions &opts)
    552       : impl_(new ReplaceFstImpl<A>(fst_array, opts)) {}
    553 
    554   ReplaceFst(const ReplaceFst<A>& fst) :
    555       impl_(new ReplaceFstImpl<A>(*(fst.impl_))) {}
    556 
    557   virtual ~ReplaceFst() {
    558     delete impl_;
    559   }
    560 
    561   virtual StateId Start() const {
    562     return impl_->Start();
    563   }
    564 
    565   virtual Weight Final(StateId s) const {
    566     return impl_->Final(s);
    567   }
    568 
    569   virtual size_t NumArcs(StateId s) const {
    570     return impl_->NumArcs(s);
    571   }
    572 
    573   virtual size_t NumInputEpsilons(StateId s) const {
    574     return impl_->NumInputEpsilons(s);
    575   }
    576 
    577   virtual size_t NumOutputEpsilons(StateId s) const {
    578     return impl_->NumOutputEpsilons(s);
    579   }
    580 
    581   virtual uint64 Properties(uint64 mask, bool test) const {
    582     if (test) {
    583       uint64 known, test = TestProperties(*this, mask, &known);
    584       impl_->SetProperties(test, known);
    585       return test & mask;
    586     } else {
    587       return impl_->Properties(mask);
    588     }
    589   }
    590 
    591   virtual const string& Type() const {
    592     return impl_->Type();
    593   }
    594 
    595   virtual ReplaceFst<A>* Copy() const {
    596     return new ReplaceFst<A>(*this);
    597   }
    598 
    599   virtual const SymbolTable* InputSymbols() const {
    600     return impl_->InputSymbols();
    601   }
    602 
    603   virtual const SymbolTable* OutputSymbols() const {
    604     return impl_->OutputSymbols();
    605   }
    606 
    607   virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
    608 
    609   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    610     impl_->InitArcIterator(s, data);
    611   }
    612 
    613   bool CyclicDependencies() const {
    614     return impl_->CyclicDependencies();
    615   }
    616 
    617  private:
    618   ReplaceFstImpl<A>* impl_;
    619 };
    620 
    621 
    622 // Specialization for ReplaceFst.
    623 template<class A>
    624 class StateIterator< ReplaceFst<A> >
    625     : public CacheStateIterator< ReplaceFst<A> > {
    626  public:
    627   explicit StateIterator(const ReplaceFst<A> &fst)
    628       : CacheStateIterator< ReplaceFst<A> >(fst) {}
    629 
    630  private:
    631   DISALLOW_EVIL_CONSTRUCTORS(StateIterator);
    632 };
    633 
    634 // Specialization for ReplaceFst.
    635 template <class A>
    636 class ArcIterator< ReplaceFst<A> >
    637     : public CacheArcIterator< ReplaceFst<A> > {
    638  public:
    639   typedef typename A::StateId StateId;
    640 
    641   ArcIterator(const ReplaceFst<A> &fst, StateId s)
    642       : CacheArcIterator< ReplaceFst<A> >(fst, s) {
    643     if (!fst.impl_->HasArcs(s))
    644       fst.impl_->Expand(s);
    645   }
    646 
    647  private:
    648   DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
    649 };
    650 
    651 template <class A> inline
    652 void ReplaceFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
    653   data->base = new StateIterator< ReplaceFst<A> >(*this);
    654 }
    655 
    656 typedef ReplaceFst<StdArc> StdReplaceFst;
    657 
    658 
    659 // // Recursivively replaces arcs in the root Fst with other Fsts.
    660 // This version writes the result of replacement to an output MutableFst.
    661 //
    662 // Replace supports replacement of arcs in one Fst with another
    663 // Fst. This replacement is recursive.  Replace takes an array of
    664 // Fst(s). One Fst represents the root (or topology) machine. The root
    665 // Fst refers to other Fsts by recursively replacing arcs labeled as
    666 // non-terminals with the matching non-terminal Fst. Currently Replace
    667 // uses the output symbols of the arcs to determine whether the arc is
    668 // a non-terminal arc or not. A non-terminal can be any label that is
    669 // not a non-zero terminal label in the output alphabet.  Note that
    670 // input argument is a vector of pair<>. These correspond to the tuple
    671 // of non-terminal Label and corresponding Fst.
    672 template<class Arc>
    673 void Replace(const vector<pair<typename Arc::Label,
    674              const Fst<Arc>* > >& ifst_array,
    675              MutableFst<Arc> *ofst, typename Arc::Label root) {
    676   ReplaceFstOptions opts(root);
    677   opts.gc_limit = 0;  // Cache only the last state for fastest copy.
    678   *ofst = ReplaceFst<Arc>(ifst_array, opts);
    679 }
    680 
    681 }
    682 
    683 #endif  // FST_LIB_REPLACE_H__
    684