Home | History | Annotate | Download | only in lib
      1 // replace.h
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 //
     16 // \file
     17 // Functions and classes for the recursive replacement of Fsts.
     18 //
     19 
     20 #ifndef FST_LIB_REPLACE_H__
     21 #define FST_LIB_REPLACE_H__
     22 
     23 #include <unordered_map>
     24 
     25 #include "fst/lib/fst.h"
     26 #include "fst/lib/cache.h"
     27 #include "fst/lib/test-properties.h"
     28 
     29 namespace fst {
     30 
     31 // By default ReplaceFst will copy the input label of the 'replace arc'.
     32 // For acceptors we do not want this behaviour. Instead we need to
     33 // create an epsilon arc when recursing into the appropriate Fst.
     34 // The epsilon_on_replace option can be used to toggle this behaviour.
     35 struct ReplaceFstOptions : CacheOptions {
     36   int64 root;    // root rule for expansion
     37   bool  epsilon_on_replace;
     38 
     39   ReplaceFstOptions(const CacheOptions &opts, int64 r)
     40       : CacheOptions(opts), root(r), epsilon_on_replace(false) {}
     41   explicit ReplaceFstOptions(int64 r)
     42       : root(r), epsilon_on_replace(false) {}
     43   ReplaceFstOptions(int64 r, bool epsilon_replace_arc)
     44       : root(r), epsilon_on_replace(epsilon_replace_arc) {}
     45   ReplaceFstOptions()
     46       : root(kNoLabel), epsilon_on_replace(false) {}
     47 };
     48 
     49 //
     50 // \class ReplaceFstImpl
     51 // \brief Implementation class for replace class Fst
     52 //
     53 // The replace implementation class supports a dynamic
     54 // expansion of a recursive transition network represented as Fst
     55 // with dynamic replacable arcs.
     56 //
     57 template <class A>
     58 class ReplaceFstImpl : public CacheImpl<A> {
     59  public:
     60   using FstImpl<A>::SetType;
     61   using FstImpl<A>::SetProperties;
     62   using FstImpl<A>::Properties;
     63   using FstImpl<A>::SetInputSymbols;
     64   using FstImpl<A>::SetOutputSymbols;
     65   using FstImpl<A>::InputSymbols;
     66   using FstImpl<A>::OutputSymbols;
     67 
     68   using CacheImpl<A>::HasStart;
     69   using CacheImpl<A>::HasArcs;
     70   using CacheImpl<A>::SetStart;
     71 
     72   typedef typename A::Label   Label;
     73   typedef typename A::Weight  Weight;
     74   typedef typename A::StateId StateId;
     75   typedef CacheState<A> State;
     76   typedef A Arc;
     77   typedef std::unordered_map<Label, Label> NonTerminalHash;
     78 
     79 
     80   // \struct StateTuple
     81   // \brief Tuple of information that uniquely defines a state
     82   struct StateTuple {
     83     typedef int PrefixId;
     84 
     85     StateTuple() {}
     86     StateTuple(PrefixId p, StateId f, StateId s) :
     87         prefix_id(p), fst_id(f), fst_state(s) {}
     88 
     89     PrefixId prefix_id;  // index in prefix table
     90     StateId fst_id;      // current fst being walked
     91     StateId fst_state;   // current state in fst being walked, not to be
     92                          // confused with the state_id of the combined fst
     93   };
     94 
     95   // constructor for replace class implementation.
     96   // \param fst_tuples array of label/fst tuples, one for each non-terminal
     97   ReplaceFstImpl(const vector< pair<Label, const Fst<A>* > >& fst_tuples,
     98                  const ReplaceFstOptions &opts)
     99       : CacheImpl<A>(opts), opts_(opts) {
    100     SetType("replace");
    101     if (fst_tuples.size() > 0) {
    102       SetInputSymbols(fst_tuples[0].second->InputSymbols());
    103       SetOutputSymbols(fst_tuples[0].second->OutputSymbols());
    104     }
    105 
    106     fst_array_.push_back(0);
    107     for (size_t i = 0; i < fst_tuples.size(); ++i)
    108       AddFst(fst_tuples[i].first, fst_tuples[i].second);
    109 
    110     SetRoot(opts.root);
    111   }
    112 
    113   explicit ReplaceFstImpl(const ReplaceFstOptions &opts)
    114       : CacheImpl<A>(opts), opts_(opts), root_(kNoLabel) {
    115     fst_array_.push_back(0);
    116   }
    117 
    118   ReplaceFstImpl(const ReplaceFstImpl& impl)
    119       : opts_(impl.opts_), state_tuples_(impl.state_tuples_),
    120         state_hash_(impl.state_hash_),
    121         prefix_hash_(impl.prefix_hash_),
    122         stackprefix_array_(impl.stackprefix_array_),
    123         nonterminal_hash_(impl.nonterminal_hash_),
    124         root_(impl.root_) {
    125     SetType("replace");
    126     SetProperties(impl.Properties(), kCopyProperties);
    127     SetInputSymbols(InputSymbols());
    128     SetOutputSymbols(OutputSymbols());
    129     fst_array_.reserve(impl.fst_array_.size());
    130     fst_array_.push_back(0);
    131     for (size_t i = 1; i < impl.fst_array_.size(); ++i)
    132       fst_array_.push_back(impl.fst_array_[i]->Copy());
    133   }
    134 
    135   ~ReplaceFstImpl() {
    136     for (size_t i = 1; i < fst_array_.size(); ++i) {
    137       delete fst_array_[i];
    138     }
    139   }
    140 
    141   // Add to Fst array
    142   void AddFst(Label label, const Fst<A>* fst) {
    143     nonterminal_hash_[label] = fst_array_.size();
    144     fst_array_.push_back(fst->Copy());
    145     if (fst_array_.size() > 1) {
    146       vector<uint64> inprops(fst_array_.size());
    147 
    148       for (size_t i = 1; i < fst_array_.size(); ++i) {
    149         inprops[i] = fst_array_[i]->Properties(kCopyProperties, false);
    150       }
    151       SetProperties(ReplaceProperties(inprops));
    152 
    153       const SymbolTable* isymbols = fst_array_[1]->InputSymbols();
    154       const SymbolTable* osymbols = fst_array_[1]->OutputSymbols();
    155       for (size_t i = 2; i < fst_array_.size(); ++i) {
    156         if (!CompatSymbols(isymbols, fst_array_[i]->InputSymbols())) {
    157           LOG(FATAL) << "ReplaceFst::AddFst input symbols of Fst " << i-1
    158                      << " does not match input symbols of base Fst (0'th fst)";
    159         }
    160         if (!CompatSymbols(osymbols, fst_array_[i]->OutputSymbols())) {
    161           LOG(FATAL) << "ReplaceFst::AddFst output symbols of Fst " << i-1
    162                      << " does not match output symbols of base Fst "
    163                      << "(0'th fst)";
    164         }
    165       }
    166     }
    167   }
    168 
    169   // Computes the dependency graph of the replace class and returns
    170   // true if the dependencies are cyclic. Cyclic dependencies will result
    171   // in an un-expandable replace fst.
    172   bool CyclicDependencies() const {
    173     StdVectorFst depfst;
    174 
    175     // one state for each fst
    176     for (size_t i = 1; i < fst_array_.size(); ++i)
    177       depfst.AddState();
    178 
    179     // an arc from each state (representing the fst) to the
    180     // state representing the fst being replaced
    181     for (size_t i = 1; i < fst_array_.size(); ++i) {
    182       for (StateIterator<Fst<A> > siter(*(fst_array_[i]));
    183            !siter.Done(); siter.Next()) {
    184         for (ArcIterator<Fst<A> > aiter(*(fst_array_[i]), siter.Value());
    185              !aiter.Done(); aiter.Next()) {
    186           const A& arc = aiter.Value();
    187 
    188           typename NonTerminalHash::const_iterator it =
    189               nonterminal_hash_.find(arc.olabel);
    190           if (it != nonterminal_hash_.end()) {
    191             Label j = it->second - 1;
    192             depfst.AddArc(i - 1, A(arc.olabel, arc.olabel, Weight::One(), j));
    193           }
    194         }
    195       }
    196     }
    197 
    198     depfst.SetStart(root_ - 1);
    199     depfst.SetFinal(root_ - 1, Weight::One());
    200     return depfst.Properties(kCyclic, true);
    201   }
    202 
    203   // set root rule for expansion
    204   void SetRoot(Label root) {
    205     Label nonterminal = nonterminal_hash_[root];
    206     root_ = (nonterminal > 0) ? nonterminal : 1;
    207   }
    208 
    209   // Change Fst array
    210   void SetFst(Label label, const Fst<A>* fst) {
    211     Label nonterminal = nonterminal_hash_[label];
    212     delete fst_array_[nonterminal];
    213     fst_array_[nonterminal] = fst->Copy();
    214   }
    215 
    216   // Return or compute start state of replace fst
    217   StateId Start() {
    218     if (!HasStart()) {
    219       if (fst_array_.size() == 1) {      // no fsts defined for replace
    220         SetStart(kNoStateId);
    221         return kNoStateId;
    222       } else {
    223         const Fst<A>* fst = fst_array_[root_];
    224         StateId fst_start = fst->Start();
    225         if (fst_start == kNoStateId)  // root Fst is empty
    226           return kNoStateId;
    227 
    228         int prefix = PrefixId(StackPrefix());
    229         StateId start = FindState(StateTuple(prefix, root_, fst_start));
    230         SetStart(start);
    231         return start;
    232       }
    233     } else {
    234       return CacheImpl<A>::Start();
    235     }
    236   }
    237 
    238   // return final weight of state (kInfWeight means state is not final)
    239   Weight Final(StateId s) {
    240     if (!HasFinal(s)) {
    241       const StateTuple& tuple  = state_tuples_[s];
    242       const StackPrefix& stack = stackprefix_array_[tuple.prefix_id];
    243       const Fst<A>* fst = fst_array_[tuple.fst_id];
    244       StateId fst_state = tuple.fst_state;
    245 
    246       if (fst->Final(fst_state) != Weight::Zero() && stack.Depth() == 0)
    247         SetFinal(s, fst->Final(fst_state));
    248       else
    249         SetFinal(s, Weight::Zero());
    250     }
    251     return CacheImpl<A>::Final(s);
    252   }
    253 
    254   size_t NumArcs(StateId s) {
    255     if (!HasArcs(s))
    256       Expand(s);
    257     return CacheImpl<A>::NumArcs(s);
    258   }
    259 
    260   size_t NumInputEpsilons(StateId s) {
    261     if (!HasArcs(s))
    262       Expand(s);
    263     return CacheImpl<A>::NumInputEpsilons(s);
    264   }
    265 
    266   size_t NumOutputEpsilons(StateId s) {
    267     if (!HasArcs(s))
    268       Expand(s);
    269     return CacheImpl<A>::NumOutputEpsilons(s);
    270   }
    271 
    272   // return the base arc iterator, if arcs have not been computed yet,
    273   // extend/recurse for new arcs.
    274   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
    275     if (!HasArcs(s))
    276       Expand(s);
    277     CacheImpl<A>::InitArcIterator(s, data);
    278   }
    279 
    280   // Find/create an Fst state given a StateTuple.  Only create a new
    281   // state if StateTuple is not found in the state hash.
    282   StateId FindState(const StateTuple& tuple) {
    283     typename StateTupleHash::iterator it = state_hash_.find(tuple);
    284     if (it == state_hash_.end()) {
    285       StateId new_state_id = state_tuples_.size();
    286       state_tuples_.push_back(tuple);
    287       state_hash_[tuple] = new_state_id;
    288       return new_state_id;
    289     } else {
    290       return it->second;
    291     }
    292   }
    293 
    294   // extend current state (walk arcs one level deep)
    295   void Expand(StateId s) {
    296     StateTuple tuple  = state_tuples_[s];
    297     const Fst<A>* fst = fst_array_[tuple.fst_id];
    298     StateId fst_state = tuple.fst_state;
    299     if (fst_state == kNoStateId) {
    300       SetArcs(s);
    301       return;
    302     }
    303 
    304     // if state is final, pop up stack
    305     const StackPrefix& stack = stackprefix_array_[tuple.prefix_id];
    306     if (fst->Final(fst_state) != Weight::Zero() && stack.Depth()) {
    307       int prefix_id = PopPrefix(stack);
    308       const PrefixTuple& top = stack.Top();
    309 
    310       StateId nextstate =
    311         FindState(StateTuple(prefix_id, top.fst_id, top.nextstate));
    312       AddArc(s, A(0, 0, fst->Final(fst_state), nextstate));
    313     }
    314 
    315     // extend arcs leaving the state
    316     for (ArcIterator< Fst<A> > aiter(*fst, fst_state);
    317          !aiter.Done(); aiter.Next()) {
    318       const Arc& arc = aiter.Value();
    319       if (arc.olabel == 0) {  // expand local fst
    320         StateId nextstate =
    321           FindState(StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate));
    322         AddArc(s, A(arc.ilabel, arc.olabel, arc.weight, nextstate));
    323       } else {
    324         // check for non terminal
    325         typename NonTerminalHash::const_iterator it =
    326             nonterminal_hash_.find(arc.olabel);
    327         if (it != nonterminal_hash_.end()) {  // recurse into non terminal
    328           Label nonterminal = it->second;
    329           const Fst<A>* nt_fst = fst_array_[nonterminal];
    330           int nt_prefix = PushPrefix(stackprefix_array_[tuple.prefix_id],
    331                                      tuple.fst_id, arc.nextstate);
    332 
    333           // if start state is valid replace, else arc is implicitly
    334           // deleted
    335           StateId nt_start = nt_fst->Start();
    336           if (nt_start != kNoStateId) {
    337             StateId nt_nextstate = FindState(
    338                 StateTuple(nt_prefix, nonterminal, nt_start));
    339             Label ilabel = (opts_.epsilon_on_replace) ? 0 : arc.ilabel;
    340             AddArc(s, A(ilabel, 0, arc.weight, nt_nextstate));
    341           }
    342         } else {
    343           StateId nextstate =
    344             FindState(
    345                 StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate));
    346           AddArc(s, A(arc.ilabel, arc.olabel, arc.weight, nextstate));
    347         }
    348       }
    349     }
    350 
    351     SetArcs(s);
    352   }
    353 
    354 
    355   // private helper classes
    356  private:
    357   static const int kPrime0 = 7853;
    358   static const int kPrime1 = 7867;
    359 
    360   // \class StateTupleEqual
    361   // \brief Compare two StateTuples for equality
    362   class StateTupleEqual {
    363    public:
    364     bool operator()(const StateTuple& x, const StateTuple& y) const {
    365       return ((x.prefix_id == y.prefix_id) && (x.fst_id == y.fst_id) &&
    366               (x.fst_state == y.fst_state));
    367     }
    368   };
    369 
    370   // \class StateTupleKey
    371   // \brief Hash function for StateTuple to Fst states
    372   class StateTupleKey {
    373    public:
    374     size_t operator()(const StateTuple& x) const {
    375       return static_cast<size_t>(x.prefix_id +
    376                                  x.fst_id * kPrime0 +
    377                                  x.fst_state * kPrime1);
    378     }
    379   };
    380 
    381   typedef std::unordered_map<StateTuple, StateId, StateTupleKey, StateTupleEqual>
    382   StateTupleHash;
    383 
    384   // \class PrefixTuple
    385   // \brief Tuple of fst_id and destination state (entry in stack prefix)
    386   struct PrefixTuple {
    387     PrefixTuple(Label f, StateId s) : fst_id(f), nextstate(s) {}
    388 
    389     Label   fst_id;
    390     StateId nextstate;
    391   };
    392 
    393   // \class StackPrefix
    394   // \brief Container for stack prefix.
    395   class StackPrefix {
    396    public:
    397     StackPrefix() {}
    398 
    399     // copy constructor
    400     StackPrefix(const StackPrefix& x) :
    401         prefix_(x.prefix_) {
    402     }
    403 
    404     void Push(int fst_id, StateId nextstate) {
    405       prefix_.push_back(PrefixTuple(fst_id, nextstate));
    406     }
    407 
    408     void Pop() {
    409       prefix_.pop_back();
    410     }
    411 
    412     const PrefixTuple& Top() const {
    413       return prefix_[prefix_.size()-1];
    414     }
    415 
    416     size_t Depth() const {
    417       return prefix_.size();
    418     }
    419 
    420    public:
    421     vector<PrefixTuple> prefix_;
    422   };
    423 
    424 
    425   // \class StackPrefixEqual
    426   // \brief Compare two stack prefix classes for equality
    427   class StackPrefixEqual {
    428    public:
    429     bool operator()(const StackPrefix& x, const StackPrefix& y) const {
    430       if (x.prefix_.size() != y.prefix_.size()) return false;
    431       for (size_t i = 0; i < x.prefix_.size(); ++i) {
    432         if (x.prefix_[i].fst_id    != y.prefix_[i].fst_id ||
    433            x.prefix_[i].nextstate != y.prefix_[i].nextstate) return false;
    434       }
    435       return true;
    436     }
    437   };
    438 
    439   //
    440   // \class StackPrefixKey
    441   // \brief Hash function for stack prefix to prefix id
    442   class StackPrefixKey {
    443    public:
    444     size_t operator()(const StackPrefix& x) const {
    445       int sum = 0;
    446       for (size_t i = 0; i < x.prefix_.size(); ++i) {
    447         sum += x.prefix_[i].fst_id + x.prefix_[i].nextstate*kPrime0;
    448       }
    449       return (size_t) sum;
    450     }
    451   };
    452 
    453   typedef std::unordered_map<StackPrefix, int, StackPrefixKey, StackPrefixEqual>
    454   StackPrefixHash;
    455 
    456   // private methods
    457  private:
    458   // hash stack prefix (return unique index into stackprefix array)
    459   int PrefixId(const StackPrefix& prefix) {
    460     typename StackPrefixHash::iterator it = prefix_hash_.find(prefix);
    461     if (it == prefix_hash_.end()) {
    462       int prefix_id = stackprefix_array_.size();
    463       stackprefix_array_.push_back(prefix);
    464       prefix_hash_[prefix] = prefix_id;
    465       return prefix_id;
    466     } else {
    467       return it->second;
    468     }
    469   }
    470 
    471   // prefix id after a stack pop
    472   int PopPrefix(StackPrefix prefix) {
    473     prefix.Pop();
    474     return PrefixId(prefix);
    475   }
    476 
    477   // prefix id after a stack push
    478   int PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) {
    479     prefix.Push(fst_id, nextstate);
    480     return PrefixId(prefix);
    481   }
    482 
    483 
    484   // private data
    485  private:
    486   // runtime options
    487   ReplaceFstOptions opts_;
    488 
    489   // maps from StateId to StateTuple
    490   vector<StateTuple> state_tuples_;
    491 
    492   // hashes from StateTuple to StateId
    493   StateTupleHash state_hash_;
    494 
    495   // cross index of unique stack prefix
    496   // could potentially have one copy of prefix array
    497   StackPrefixHash prefix_hash_;
    498   vector<StackPrefix> stackprefix_array_;
    499 
    500   NonTerminalHash nonterminal_hash_;
    501   vector<const Fst<A>*> fst_array_;
    502 
    503   Label root_;
    504 
    505   void operator=(const ReplaceFstImpl<A> &);  // disallow
    506 };
    507 
    508 
    509 //
    510 // \class ReplaceFst
    511 // \brief Recursivively replaces arcs in the root Fst with other Fsts.
    512 // This version is a delayed Fst.
    513 //
    514 // ReplaceFst supports dynamic replacement of arcs in one Fst with
    515 // another Fst. This replacement is recursive.  ReplaceFst can be used
    516 // to support a variety of delayed constructions such as recursive
    517 // transition networks, union, or closure.  It is constructed with an
    518 // array of Fst(s). One Fst represents the root (or topology)
    519 // machine. The root Fst refers to other Fsts by recursively replacing
    520 // arcs labeled as non-terminals with the matching non-terminal
    521 // Fst. Currently the ReplaceFst uses the output symbols of the arcs
    522 // to determine whether the arc is a non-terminal arc or not. A
    523 // non-terminal can be any label that is not a non-zero terminal label
    524 // in the output alphabet.
    525 //
    526 // Note that the constructor uses a vector of pair<>. These correspond
    527 // to the tuple of non-terminal Label and corresponding Fst. For example
    528 // to implement the closure operation we need 2 Fsts. The first root
    529 // Fst is a single Arc on the start State that self loops, it references
    530 // the particular machine for which we are performing the closure operation.
    531 //
    532 template <class A>
    533 class ReplaceFst : public Fst<A> {
    534  public:
    535   friend class ArcIterator< ReplaceFst<A> >;
    536   friend class CacheStateIterator< ReplaceFst<A> >;
    537   friend class CacheArcIterator< ReplaceFst<A> >;
    538 
    539   typedef A Arc;
    540   typedef typename A::Label   Label;
    541   typedef typename A::Weight  Weight;
    542   typedef typename A::StateId StateId;
    543   typedef CacheState<A> State;
    544 
    545   ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
    546              Label root)
    547       : impl_(new ReplaceFstImpl<A>(fst_array, ReplaceFstOptions(root))) {}
    548 
    549   ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
    550              const ReplaceFstOptions &opts)
    551       : impl_(new ReplaceFstImpl<A>(fst_array, opts)) {}
    552 
    553   ReplaceFst(const ReplaceFst<A>& fst) :
    554       impl_(new ReplaceFstImpl<A>(*(fst.impl_))) {}
    555 
    556   virtual ~ReplaceFst() {
    557     delete impl_;
    558   }
    559 
    560   virtual StateId Start() const {
    561     return impl_->Start();
    562   }
    563 
    564   virtual Weight Final(StateId s) const {
    565     return impl_->Final(s);
    566   }
    567 
    568   virtual size_t NumArcs(StateId s) const {
    569     return impl_->NumArcs(s);
    570   }
    571 
    572   virtual size_t NumInputEpsilons(StateId s) const {
    573     return impl_->NumInputEpsilons(s);
    574   }
    575 
    576   virtual size_t NumOutputEpsilons(StateId s) const {
    577     return impl_->NumOutputEpsilons(s);
    578   }
    579 
    580   virtual uint64 Properties(uint64 mask, bool test) const {
    581     if (test) {
    582       uint64 known, test = TestProperties(*this, mask, &known);
    583       impl_->SetProperties(test, known);
    584       return test & mask;
    585     } else {
    586       return impl_->Properties(mask);
    587     }
    588   }
    589 
    590   virtual const string& Type() const {
    591     return impl_->Type();
    592   }
    593 
    594   virtual ReplaceFst<A>* Copy() const {
    595     return new ReplaceFst<A>(*this);
    596   }
    597 
    598   virtual const SymbolTable* InputSymbols() const {
    599     return impl_->InputSymbols();
    600   }
    601 
    602   virtual const SymbolTable* OutputSymbols() const {
    603     return impl_->OutputSymbols();
    604   }
    605 
    606   virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
    607 
    608   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    609     impl_->InitArcIterator(s, data);
    610   }
    611 
    612   bool CyclicDependencies() const {
    613     return impl_->CyclicDependencies();
    614   }
    615 
    616  private:
    617   ReplaceFstImpl<A>* impl_;
    618 };
    619 
    620 
    621 // Specialization for ReplaceFst.
    622 template<class A>
    623 class StateIterator< ReplaceFst<A> >
    624     : public CacheStateIterator< ReplaceFst<A> > {
    625  public:
    626   explicit StateIterator(const ReplaceFst<A> &fst)
    627       : CacheStateIterator< ReplaceFst<A> >(fst) {}
    628 
    629  private:
    630   DISALLOW_EVIL_CONSTRUCTORS(StateIterator);
    631 };
    632 
    633 // Specialization for ReplaceFst.
    634 template <class A>
    635 class ArcIterator< ReplaceFst<A> >
    636     : public CacheArcIterator< ReplaceFst<A> > {
    637  public:
    638   typedef typename A::StateId StateId;
    639 
    640   ArcIterator(const ReplaceFst<A> &fst, StateId s)
    641       : CacheArcIterator< ReplaceFst<A> >(fst, s) {
    642     if (!fst.impl_->HasArcs(s))
    643       fst.impl_->Expand(s);
    644   }
    645 
    646  private:
    647   DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
    648 };
    649 
    650 template <class A> inline
    651 void ReplaceFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
    652   data->base = new StateIterator< ReplaceFst<A> >(*this);
    653 }
    654 
    655 typedef ReplaceFst<StdArc> StdReplaceFst;
    656 
    657 
    658 // // Recursivively replaces arcs in the root Fst with other Fsts.
    659 // This version writes the result of replacement to an output MutableFst.
    660 //
    661 // Replace supports replacement of arcs in one Fst with another
    662 // Fst. This replacement is recursive.  Replace takes an array of
    663 // Fst(s). One Fst represents the root (or topology) machine. The root
    664 // Fst refers to other Fsts by recursively replacing arcs labeled as
    665 // non-terminals with the matching non-terminal Fst. Currently Replace
    666 // uses the output symbols of the arcs to determine whether the arc is
    667 // a non-terminal arc or not. A non-terminal can be any label that is
    668 // not a non-zero terminal label in the output alphabet.  Note that
    669 // input argument is a vector of pair<>. These correspond to the tuple
    670 // of non-terminal Label and corresponding Fst.
    671 template<class Arc>
    672 void Replace(const vector<pair<typename Arc::Label,
    673              const Fst<Arc>* > >& ifst_array,
    674              MutableFst<Arc> *ofst, typename Arc::Label root) {
    675   ReplaceFstOptions opts(root);
    676   opts.gc_limit = 0;  // Cache only the last state for fastest copy.
    677   *ofst = ReplaceFst<Arc>(ifst_array, opts);
    678 }
    679 
    680 }
    681 
    682 #endif  // FST_LIB_REPLACE_H__
    683