Home | History | Annotate | Download | only in fst
      1 // replace.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: johans (at) google.com (Johan Schalkwyk)
     17 //
     18 // \file
     19 // Functions and classes for the recursive replacement of Fsts.
     20 //
     21 
     22 #ifndef FST_LIB_REPLACE_H__
     23 #define FST_LIB_REPLACE_H__
     24 
     25 #include <tr1/unordered_map>
     26 using std::tr1::unordered_map;
     27 using std::tr1::unordered_multimap;
     28 #include <set>
     29 #include <string>
     30 #include <utility>
     31 using std::pair; using std::make_pair;
     32 #include <vector>
     33 using std::vector;
     34 
     35 #include <fst/cache.h>
     36 #include <fst/expanded-fst.h>
     37 #include <fst/fst.h>
     38 #include <fst/matcher.h>
     39 #include <fst/replace-util.h>
     40 #include <fst/state-table.h>
     41 #include <fst/test-properties.h>
     42 
     43 namespace fst {
     44 
     45 //
     46 // REPLACE STATE TUPLES AND TABLES
     47 //
     48 // The replace state table has the form
     49 //
     50 // template <class A, class P>
     51 // class ReplaceStateTable {
     52 //  public:
     53 //   typedef A Arc;
     54 //   typedef P PrefixId;
     55 //   typedef typename A::StateId StateId;
     56 //   typedef ReplaceStateTuple<StateId, PrefixId> StateTuple;
     57 //   typedef typename A::Label Label;
     58 //
     59 //   // Required constuctor
     60 //   ReplaceStateTable(const vector<pair<Label, const Fst<A>*> > &fst_tuples,
     61 //                     Label root);
     62 //
     63 //   // Required copy constructor that does not copy state
     64 //   ReplaceStateTable(const ReplaceStateTable<A,P> &table);
     65 //
     66 //   // Lookup state ID by tuple. If it doesn't exist, then add it.
     67 //   StateId FindState(const StateTuple &tuple);
     68 //
     69 //   // Lookup state tuple by ID.
     70 //   const StateTuple &Tuple(StateId id) const;
     71 // };
     72 
     73 
     74 // \struct ReplaceStateTuple
     75 // \brief Tuple of information that uniquely defines a state in replace
     76 template <class S, class P>
     77 struct ReplaceStateTuple {
     78   typedef S StateId;
     79   typedef P PrefixId;
     80 
     81   ReplaceStateTuple()
     82       : prefix_id(-1), fst_id(kNoStateId), fst_state(kNoStateId) {}
     83 
     84   ReplaceStateTuple(PrefixId p, StateId f, StateId s)
     85       : prefix_id(p), fst_id(f), fst_state(s) {}
     86 
     87   PrefixId prefix_id;  // index in prefix table
     88   StateId fst_id;      // current fst being walked
     89   StateId fst_state;   // current state in fst being walked, not to be
     90                        // confused with the state_id of the combined fst
     91 };
     92 
     93 
     94 // Equality of replace state tuples.
     95 template <class S, class P>
     96 inline bool operator==(const ReplaceStateTuple<S, P>& x,
     97                        const ReplaceStateTuple<S, P>& y) {
     98   return x.prefix_id == y.prefix_id &&
     99       x.fst_id == y.fst_id &&
    100       x.fst_state == y.fst_state;
    101 }
    102 
    103 
    104 // \class ReplaceRootSelector
    105 // Functor returning true for tuples corresponding to states in the root FST
    106 template <class S, class P>
    107 class ReplaceRootSelector {
    108  public:
    109   bool operator()(const ReplaceStateTuple<S, P> &tuple) const {
    110     return tuple.prefix_id == 0;
    111   }
    112 };
    113 
    114 
    115 // \class ReplaceFingerprint
    116 // Fingerprint for general replace state tuples.
    117 template <class S, class P>
    118 class ReplaceFingerprint {
    119  public:
    120   ReplaceFingerprint(const vector<uint64> *size_array)
    121       : cumulative_size_array_(size_array) {}
    122 
    123   uint64 operator()(const ReplaceStateTuple<S, P> &tuple) const {
    124     return tuple.prefix_id * (cumulative_size_array_->back()) +
    125         cumulative_size_array_->at(tuple.fst_id - 1) +
    126         tuple.fst_state;
    127   }
    128 
    129  private:
    130   const vector<uint64> *cumulative_size_array_;
    131 };
    132 
    133 
    134 // \class ReplaceFstStateFingerprint
    135 // Useful when the fst_state uniquely define the tuple.
    136 template <class S, class P>
    137 class ReplaceFstStateFingerprint {
    138  public:
    139   uint64 operator()(const ReplaceStateTuple<S, P>& tuple) const {
    140     return tuple.fst_state;
    141   }
    142 };
    143 
    144 
    145 // \class ReplaceHash
    146 // A generic hash function for replace state tuples.
    147 template <typename S, typename P>
    148 class ReplaceHash {
    149  public:
    150   size_t operator()(const ReplaceStateTuple<S, P>& t) const {
    151     return t.prefix_id + t.fst_id * kPrime0 + t.fst_state * kPrime1;
    152   }
    153  private:
    154   static const size_t kPrime0;
    155   static const size_t kPrime1;
    156 };
    157 
    158 template <typename S, typename P>
    159 const size_t ReplaceHash<S, P>::kPrime0 = 7853;
    160 
    161 template <typename S, typename P>
    162 const size_t ReplaceHash<S, P>::kPrime1 = 7867;
    163 
    164 template <class A, class T> class ReplaceFstMatcher;
    165 
    166 
    167 // \class VectorHashReplaceStateTable
    168 // A two-level state table for replace.
    169 // Warning: calls CountStates to compute the number of states of each
    170 // component Fst.
    171 template <class A, class P = ssize_t>
    172 class VectorHashReplaceStateTable {
    173  public:
    174   typedef A Arc;
    175   typedef typename A::StateId StateId;
    176   typedef typename A::Label Label;
    177   typedef P PrefixId;
    178   typedef ReplaceStateTuple<StateId, P> StateTuple;
    179   typedef VectorHashStateTable<ReplaceStateTuple<StateId, P>,
    180                                ReplaceRootSelector<StateId, P>,
    181                                ReplaceFstStateFingerprint<StateId, P>,
    182                                ReplaceFingerprint<StateId, P> > StateTable;
    183 
    184   VectorHashReplaceStateTable(
    185       const vector<pair<Label, const Fst<A>*> > &fst_tuples,
    186       Label root) : root_size_(0) {
    187     cumulative_size_array_.push_back(0);
    188     for (size_t i = 0; i < fst_tuples.size(); ++i) {
    189       if (fst_tuples[i].first == root) {
    190         root_size_ = CountStates(*(fst_tuples[i].second));
    191         cumulative_size_array_.push_back(cumulative_size_array_.back());
    192       } else {
    193         cumulative_size_array_.push_back(cumulative_size_array_.back() +
    194                                          CountStates(*(fst_tuples[i].second)));
    195       }
    196     }
    197     state_table_ = new StateTable(
    198         new ReplaceRootSelector<StateId, P>,
    199         new ReplaceFstStateFingerprint<StateId, P>,
    200         new ReplaceFingerprint<StateId, P>(&cumulative_size_array_),
    201         root_size_,
    202         root_size_ + cumulative_size_array_.back());
    203   }
    204 
    205   VectorHashReplaceStateTable(const VectorHashReplaceStateTable<A, P> &table)
    206       : root_size_(table.root_size_),
    207         cumulative_size_array_(table.cumulative_size_array_) {
    208     state_table_ = new StateTable(
    209         new ReplaceRootSelector<StateId, P>,
    210         new ReplaceFstStateFingerprint<StateId, P>,
    211         new ReplaceFingerprint<StateId, P>(&cumulative_size_array_),
    212         root_size_,
    213         root_size_ + cumulative_size_array_.back());
    214   }
    215 
    216   ~VectorHashReplaceStateTable() {
    217     delete state_table_;
    218   }
    219 
    220   StateId FindState(const StateTuple &tuple) {
    221     return state_table_->FindState(tuple);
    222   }
    223 
    224   const StateTuple &Tuple(StateId id) const {
    225     return state_table_->Tuple(id);
    226   }
    227 
    228  private:
    229   StateId root_size_;
    230   vector<uint64> cumulative_size_array_;
    231   StateTable *state_table_;
    232 };
    233 
    234 
    235 // \class DefaultReplaceStateTable
    236 // Default replace state table
    237 template <class A, class P = ssize_t>
    238 class DefaultReplaceStateTable : public CompactHashStateTable<
    239   ReplaceStateTuple<typename A::StateId, P>,
    240   ReplaceHash<typename A::StateId, P> > {
    241  public:
    242   typedef A Arc;
    243   typedef typename A::StateId StateId;
    244   typedef typename A::Label Label;
    245   typedef P PrefixId;
    246   typedef ReplaceStateTuple<StateId, P> StateTuple;
    247   typedef CompactHashStateTable<StateTuple,
    248                                 ReplaceHash<StateId, PrefixId> > StateTable;
    249 
    250   using StateTable::FindState;
    251   using StateTable::Tuple;
    252 
    253   DefaultReplaceStateTable(
    254       const vector<pair<Label, const Fst<A>*> > &fst_tuples,
    255       Label root) {}
    256 
    257   DefaultReplaceStateTable(const DefaultReplaceStateTable<A, P> &table)
    258       : StateTable() {}
    259 };
    260 
    261 //
    262 // REPLACE FST CLASS
    263 //
    264 
    265 // By default ReplaceFst will copy the input label of the 'replace arc'.
    266 // For acceptors we do not want this behaviour. Instead we need to
    267 // create an epsilon arc when recursing into the appropriate Fst.
    268 // The 'epsilon_on_replace' option can be used to toggle this behaviour.
    269 template <class A, class T = DefaultReplaceStateTable<A> >
    270 struct ReplaceFstOptions : CacheOptions {
    271   int64 root;    // root rule for expansion
    272   bool  epsilon_on_replace;
    273   bool  take_ownership;  // take ownership of input Fst(s)
    274   T*    state_table;
    275 
    276   ReplaceFstOptions(const CacheOptions &opts, int64 r)
    277       : CacheOptions(opts),
    278         root(r),
    279         epsilon_on_replace(false),
    280         take_ownership(false),
    281         state_table(0) {}
    282   explicit ReplaceFstOptions(int64 r)
    283       : root(r),
    284         epsilon_on_replace(false),
    285         take_ownership(false),
    286         state_table(0) {}
    287   ReplaceFstOptions(int64 r, bool epsilon_replace_arc)
    288       : root(r),
    289         epsilon_on_replace(epsilon_replace_arc),
    290         take_ownership(false),
    291         state_table(0) {}
    292   ReplaceFstOptions()
    293       : root(kNoLabel),
    294         epsilon_on_replace(false),
    295         take_ownership(false),
    296         state_table(0) {}
    297 };
    298 
    299 
    300 // \class ReplaceFstImpl
    301 // \brief Implementation class for replace class Fst
    302 //
    303 // The replace implementation class supports a dynamic
    304 // expansion of a recursive transition network represented as Fst
    305 // with dynamic replacable arcs.
    306 //
    307 template <class A, class T>
    308 class ReplaceFstImpl : public CacheImpl<A> {
    309   friend class ReplaceFstMatcher<A, T>;
    310 
    311  public:
    312   using FstImpl<A>::SetType;
    313   using FstImpl<A>::SetProperties;
    314   using FstImpl<A>::WriteHeader;
    315   using FstImpl<A>::SetInputSymbols;
    316   using FstImpl<A>::SetOutputSymbols;
    317   using FstImpl<A>::InputSymbols;
    318   using FstImpl<A>::OutputSymbols;
    319 
    320   using CacheImpl<A>::PushArc;
    321   using CacheImpl<A>::HasArcs;
    322   using CacheImpl<A>::HasFinal;
    323   using CacheImpl<A>::HasStart;
    324   using CacheImpl<A>::SetArcs;
    325   using CacheImpl<A>::SetFinal;
    326   using CacheImpl<A>::SetStart;
    327 
    328   typedef typename A::Label   Label;
    329   typedef typename A::Weight  Weight;
    330   typedef typename A::StateId StateId;
    331   typedef CacheState<A> State;
    332   typedef A Arc;
    333   typedef unordered_map<Label, Label> NonTerminalHash;
    334 
    335   typedef T StateTable;
    336   typedef typename T::PrefixId PrefixId;
    337   typedef ReplaceStateTuple<StateId, PrefixId> StateTuple;
    338 
    339   // constructor for replace class implementation.
    340   // \param fst_tuples array of label/fst tuples, one for each non-terminal
    341   ReplaceFstImpl(const vector< pair<Label, const Fst<A>* > >& fst_tuples,
    342                  const ReplaceFstOptions<A, T> &opts)
    343       : CacheImpl<A>(opts),
    344         epsilon_on_replace_(opts.epsilon_on_replace),
    345         state_table_(opts.state_table ? opts.state_table :
    346                      new StateTable(fst_tuples, opts.root)) {
    347 
    348     SetType("replace");
    349 
    350     if (fst_tuples.size() > 0) {
    351       SetInputSymbols(fst_tuples[0].second->InputSymbols());
    352       SetOutputSymbols(fst_tuples[0].second->OutputSymbols());
    353     }
    354 
    355     bool all_negative = true;  // all nonterminals are negative?
    356     bool dense_range = true;   // all nonterminals are positive
    357                                // and form a dense range containing 1?
    358     for (size_t i = 0; i < fst_tuples.size(); ++i) {
    359       Label nonterminal = fst_tuples[i].first;
    360       if (nonterminal >= 0)
    361         all_negative = false;
    362       if (nonterminal > fst_tuples.size() || nonterminal <= 0)
    363         dense_range = false;
    364     }
    365 
    366     vector<uint64> inprops;
    367     bool all_ilabel_sorted = true;
    368     bool all_olabel_sorted = true;
    369     bool all_non_empty = true;
    370     fst_array_.push_back(0);
    371     for (size_t i = 0; i < fst_tuples.size(); ++i) {
    372       Label label = fst_tuples[i].first;
    373       const Fst<A> *fst = fst_tuples[i].second;
    374       nonterminal_hash_[label] = fst_array_.size();
    375       nonterminal_set_.insert(label);
    376       fst_array_.push_back(opts.take_ownership ? fst : fst->Copy());
    377       if (fst->Start() == kNoStateId)
    378         all_non_empty = false;
    379       if(!fst->Properties(kILabelSorted, false))
    380         all_ilabel_sorted = false;
    381       if(!fst->Properties(kOLabelSorted, false))
    382         all_olabel_sorted = false;
    383       inprops.push_back(fst->Properties(kCopyProperties, false));
    384       if (i) {
    385         if (!CompatSymbols(InputSymbols(), fst->InputSymbols())) {
    386           FSTERROR() << "ReplaceFstImpl: input symbols of Fst " << i
    387                      << " does not match input symbols of base Fst (0'th fst)";
    388           SetProperties(kError, kError);
    389         }
    390         if (!CompatSymbols(OutputSymbols(), fst->OutputSymbols())) {
    391           FSTERROR() << "ReplaceFstImpl: output symbols of Fst " << i
    392                      << " does not match output symbols of base Fst "
    393                      << "(0'th fst)";
    394           SetProperties(kError, kError);
    395         }
    396       }
    397     }
    398     Label nonterminal = nonterminal_hash_[opts.root];
    399     if ((nonterminal == 0) && (fst_array_.size() > 1)) {
    400       FSTERROR() << "ReplaceFstImpl: no Fst corresponding to root label '"
    401                  << opts.root << "' in the input tuple vector";
    402       SetProperties(kError, kError);
    403     }
    404     root_ = (nonterminal > 0) ? nonterminal : 1;
    405 
    406     SetProperties(ReplaceProperties(inprops, root_ - 1, epsilon_on_replace_,
    407                                     all_non_empty));
    408     // We assume that all terminals are positive.  The resulting
    409     // ReplaceFst is known to be kILabelSorted when all sub-FSTs are
    410     // kILabelSorted and one of the 3 following conditions is satisfied:
    411     //  1. 'epsilon_on_replace' is false, or
    412     //  2. all non-terminals are negative, or
    413     //  3. all non-terninals are positive and form a dense range containing 1.
    414     if (all_ilabel_sorted &&
    415         (!epsilon_on_replace_ || all_negative || dense_range))
    416       SetProperties(kILabelSorted, kILabelSorted);
    417     // Similarly, the resulting ReplaceFst is known to be
    418     // kOLabelSorted when all sub-FSTs are kOLabelSorted and one of
    419     // the 2 following conditions is satisfied:
    420     //  1. all non-terminals are negative, or
    421     //  2. all non-terninals are positive and form a dense range containing 1.
    422     if (all_olabel_sorted && (all_negative || dense_range))
    423       SetProperties(kOLabelSorted, kOLabelSorted);
    424 
    425     // Enable optional caching as long as sorted and all non empty.
    426     if (Properties(kILabelSorted | kOLabelSorted) && all_non_empty)
    427       always_cache_ = false;
    428     else
    429       always_cache_ = true;
    430     VLOG(2) << "ReplaceFstImpl::ReplaceFstImpl: always_cache = "
    431             << (always_cache_ ? "true" : "false");
    432   }
    433 
    434   ReplaceFstImpl(const ReplaceFstImpl& impl)
    435       : CacheImpl<A>(impl),
    436         epsilon_on_replace_(impl.epsilon_on_replace_),
    437         always_cache_(impl.always_cache_),
    438         state_table_(new StateTable(*(impl.state_table_))),
    439         nonterminal_set_(impl.nonterminal_set_),
    440         nonterminal_hash_(impl.nonterminal_hash_),
    441         root_(impl.root_) {
    442     SetType("replace");
    443     SetProperties(impl.Properties(), kCopyProperties);
    444     SetInputSymbols(impl.InputSymbols());
    445     SetOutputSymbols(impl.OutputSymbols());
    446     fst_array_.reserve(impl.fst_array_.size());
    447     fst_array_.push_back(0);
    448     for (size_t i = 1; i < impl.fst_array_.size(); ++i) {
    449       fst_array_.push_back(impl.fst_array_[i]->Copy(true));
    450     }
    451   }
    452 
    453   ~ReplaceFstImpl() {
    454     VLOG(2) << "~ReplaceFstImpl: gc = "
    455             << (CacheImpl<A>::GetCacheGc() ? "true" : "false")
    456             << ", gc_size = " << CacheImpl<A>::GetCacheSize()
    457             << ", gc_limit = " << CacheImpl<A>::GetCacheLimit();
    458 
    459     delete state_table_;
    460     for (size_t i = 1; i < fst_array_.size(); ++i) {
    461       delete fst_array_[i];
    462     }
    463   }
    464 
    465   // Computes the dependency graph of the replace class and returns
    466   // true if the dependencies are cyclic. Cyclic dependencies will result
    467   // in an un-expandable replace fst.
    468   bool CyclicDependencies() const {
    469     ReplaceUtil<A> replace_util(fst_array_, nonterminal_hash_, root_);
    470     return replace_util.CyclicDependencies();
    471   }
    472 
    473   // Return or compute start state of replace fst
    474   StateId Start() {
    475     if (!HasStart()) {
    476       if (fst_array_.size() == 1) {      // no fsts defined for replace
    477         SetStart(kNoStateId);
    478         return kNoStateId;
    479       } else {
    480         const Fst<A>* fst = fst_array_[root_];
    481         StateId fst_start = fst->Start();
    482         if (fst_start == kNoStateId)  // root Fst is empty
    483           return kNoStateId;
    484 
    485         PrefixId prefix = GetPrefixId(StackPrefix());
    486         StateId start = state_table_->FindState(
    487             StateTuple(prefix, root_, fst_start));
    488         SetStart(start);
    489         return start;
    490       }
    491     } else {
    492       return CacheImpl<A>::Start();
    493     }
    494   }
    495 
    496   // return final weight of state (kInfWeight means state is not final)
    497   Weight Final(StateId s) {
    498     if (!HasFinal(s)) {
    499       const StateTuple& tuple  = state_table_->Tuple(s);
    500       const StackPrefix& stack = stackprefix_array_[tuple.prefix_id];
    501       const Fst<A>* fst = fst_array_[tuple.fst_id];
    502       StateId fst_state = tuple.fst_state;
    503 
    504       if (fst->Final(fst_state) != Weight::Zero() && stack.Depth() == 0)
    505         SetFinal(s, fst->Final(fst_state));
    506       else
    507         SetFinal(s, Weight::Zero());
    508     }
    509     return CacheImpl<A>::Final(s);
    510   }
    511 
    512   size_t NumArcs(StateId s) {
    513     if (HasArcs(s)) {  // If state cached, use the cached value.
    514       return CacheImpl<A>::NumArcs(s);
    515     } else if (always_cache_) {  // If always caching, expand and cache state.
    516       Expand(s);
    517       return CacheImpl<A>::NumArcs(s);
    518     } else {  // Otherwise compute the number of arcs without expanding.
    519       StateTuple tuple  = state_table_->Tuple(s);
    520       if (tuple.fst_state == kNoStateId)
    521         return 0;
    522 
    523       const Fst<A>* fst = fst_array_[tuple.fst_id];
    524       size_t num_arcs = fst->NumArcs(tuple.fst_state);
    525       if (ComputeFinalArc(tuple, 0))
    526         num_arcs++;
    527 
    528       return num_arcs;
    529     }
    530   }
    531 
    532   // Returns whether a given label is a non terminal
    533   bool IsNonTerminal(Label l) const {
    534     // TODO(allauzen): be smarter and take advantage of
    535     // all_dense or all_negative.
    536     // Use also in ComputeArc, this would require changes to replace
    537     // so that recursing into an empty fst lead to a non co-accessible
    538     // state instead of deleting the arc as done currently.
    539     // Current use correct, since i/olabel sorted iff all_non_empty.
    540     typename NonTerminalHash::const_iterator it =
    541         nonterminal_hash_.find(l);
    542     return it != nonterminal_hash_.end();
    543   }
    544 
    545   size_t NumInputEpsilons(StateId s) {
    546     if (HasArcs(s)) {
    547       // If state cached, use the cached value.
    548       return CacheImpl<A>::NumInputEpsilons(s);
    549     } else if (always_cache_ || !Properties(kILabelSorted)) {
    550       // If always caching or if the number of input epsilons is too expensive
    551       // to compute without caching (i.e. not ilabel sorted),
    552       // then expand and cache state.
    553       Expand(s);
    554       return CacheImpl<A>::NumInputEpsilons(s);
    555     } else {
    556       // Otherwise, compute the number of input epsilons without caching.
    557       StateTuple tuple  = state_table_->Tuple(s);
    558       if (tuple.fst_state == kNoStateId)
    559         return 0;
    560       const Fst<A>* fst = fst_array_[tuple.fst_id];
    561       size_t num  = 0;
    562       if (!epsilon_on_replace_) {
    563         // If epsilon_on_replace is false, all input epsilon arcs
    564         // are also input epsilons arcs in the underlying machine.
    565         fst->NumInputEpsilons(tuple.fst_state);
    566       } else {
    567         // Otherwise, one need to consider that all non-terminal arcs
    568         // in the underlying machine also become input epsilon arc.
    569         ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state);
    570         for (; !aiter.Done() &&
    571                  ((aiter.Value().ilabel == 0) ||
    572                   IsNonTerminal(aiter.Value().olabel));
    573              aiter.Next())
    574           ++num;
    575       }
    576       if (ComputeFinalArc(tuple, 0))
    577         num++;
    578       return num;
    579     }
    580   }
    581 
    582   size_t NumOutputEpsilons(StateId s) {
    583     if (HasArcs(s)) {
    584       // If state cached, use the cached value.
    585       return CacheImpl<A>::NumOutputEpsilons(s);
    586     } else if(always_cache_ || !Properties(kOLabelSorted)) {
    587       // If always caching or if the number of output epsilons is too expensive
    588       // to compute without caching (i.e. not olabel sorted),
    589       // then expand and cache state.
    590       Expand(s);
    591       return CacheImpl<A>::NumOutputEpsilons(s);
    592     } else {
    593       // Otherwise, compute the number of output epsilons without caching.
    594       StateTuple tuple  = state_table_->Tuple(s);
    595       if (tuple.fst_state == kNoStateId)
    596         return 0;
    597       const Fst<A>* fst = fst_array_[tuple.fst_id];
    598       size_t num  = 0;
    599       ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state);
    600       for (; !aiter.Done() &&
    601                ((aiter.Value().olabel == 0) ||
    602                 IsNonTerminal(aiter.Value().olabel));
    603            aiter.Next())
    604         ++num;
    605       if (ComputeFinalArc(tuple, 0))
    606         num++;
    607       return num;
    608     }
    609   }
    610 
    611   uint64 Properties() const { return Properties(kFstProperties); }
    612 
    613   // Set error if found; return FST impl properties.
    614   uint64 Properties(uint64 mask) const {
    615     if (mask & kError) {
    616       for (size_t i = 1; i < fst_array_.size(); ++i) {
    617         if (fst_array_[i]->Properties(kError, false))
    618           SetProperties(kError, kError);
    619       }
    620     }
    621     return FstImpl<Arc>::Properties(mask);
    622   }
    623 
    624   // return the base arc iterator, if arcs have not been computed yet,
    625   // extend/recurse for new arcs.
    626   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
    627     if (!HasArcs(s))
    628       Expand(s);
    629     CacheImpl<A>::InitArcIterator(s, data);
    630     // TODO(allauzen): Set behaviour of generic iterator
    631     // Warning: ArcIterator<ReplaceFst<A> >::InitCache()
    632     // relies on current behaviour.
    633   }
    634 
    635 
    636   // Extend current state (walk arcs one level deep)
    637   void Expand(StateId s) {
    638     StateTuple tuple = state_table_->Tuple(s);
    639 
    640     // If local fst is empty
    641     if (tuple.fst_state == kNoStateId) {
    642       SetArcs(s);
    643       return;
    644     }
    645 
    646     ArcIterator< Fst<A> > aiter(
    647         *(fst_array_[tuple.fst_id]), tuple.fst_state);
    648     Arc arc;
    649 
    650     // Create a final arc when needed
    651     if (ComputeFinalArc(tuple, &arc))
    652       PushArc(s, arc);
    653 
    654     // Expand all arcs leaving the state
    655     for (;!aiter.Done(); aiter.Next()) {
    656       if (ComputeArc(tuple, aiter.Value(), &arc))
    657         PushArc(s, arc);
    658     }
    659 
    660     SetArcs(s);
    661   }
    662 
    663   void Expand(StateId s, const StateTuple &tuple,
    664               const ArcIteratorData<A> &data) {
    665      // If local fst is empty
    666     if (tuple.fst_state == kNoStateId) {
    667       SetArcs(s);
    668       return;
    669     }
    670 
    671     ArcIterator< Fst<A> > aiter(data);
    672     Arc arc;
    673 
    674     // Create a final arc when needed
    675     if (ComputeFinalArc(tuple, &arc))
    676       AddArc(s, arc);
    677 
    678     // Expand all arcs leaving the state
    679     for (; !aiter.Done(); aiter.Next()) {
    680       if (ComputeArc(tuple, aiter.Value(), &arc))
    681         AddArc(s, arc);
    682     }
    683 
    684     SetArcs(s);
    685   }
    686 
    687   // If arcp == 0, only returns if a final arc is required, does not
    688   // actually compute it.
    689   bool ComputeFinalArc(const StateTuple &tuple, A* arcp,
    690                        uint32 flags = kArcValueFlags) {
    691     const Fst<A>* fst = fst_array_[tuple.fst_id];
    692     StateId fst_state = tuple.fst_state;
    693     if (fst_state == kNoStateId)
    694       return false;
    695 
    696    // if state is final, pop up stack
    697     const StackPrefix& stack = stackprefix_array_[tuple.prefix_id];
    698     if (fst->Final(fst_state) != Weight::Zero() && stack.Depth()) {
    699       if (arcp) {
    700         arcp->ilabel = 0;
    701         arcp->olabel = 0;
    702         if (flags & kArcNextStateValue) {
    703           PrefixId prefix_id = PopPrefix(stack);
    704           const PrefixTuple& top = stack.Top();
    705           arcp->nextstate = state_table_->FindState(
    706               StateTuple(prefix_id, top.fst_id, top.nextstate));
    707         }
    708         if (flags & kArcWeightValue)
    709           arcp->weight = fst->Final(fst_state);
    710       }
    711       return true;
    712     } else {
    713       return false;
    714     }
    715   }
    716 
    717   // Compute the arc in the replace fst corresponding to a given
    718   // in the underlying machine. Returns false if the underlying arc
    719   // corresponds to no arc in the replace.
    720   bool ComputeArc(const StateTuple &tuple, const A &arc, A* arcp,
    721                   uint32 flags = kArcValueFlags) {
    722     if (!epsilon_on_replace_ &&
    723         (flags == (flags & (kArcILabelValue | kArcWeightValue)))) {
    724       *arcp = arc;
    725       return true;
    726     }
    727 
    728     if (arc.olabel == 0) {  // expand local fst
    729       StateId nextstate = flags & kArcNextStateValue
    730           ? state_table_->FindState(
    731               StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
    732           : kNoStateId;
    733       *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate);
    734     } else {
    735       // check for non terminal
    736       typename NonTerminalHash::const_iterator it =
    737           nonterminal_hash_.find(arc.olabel);
    738       if (it != nonterminal_hash_.end()) {  // recurse into non terminal
    739         Label nonterminal = it->second;
    740         const Fst<A>* nt_fst = fst_array_[nonterminal];
    741         PrefixId nt_prefix = PushPrefix(stackprefix_array_[tuple.prefix_id],
    742                                         tuple.fst_id, arc.nextstate);
    743 
    744         // if start state is valid replace, else arc is implicitly
    745         // deleted
    746         StateId nt_start = nt_fst->Start();
    747         if (nt_start != kNoStateId) {
    748           StateId nt_nextstate =  flags & kArcNextStateValue
    749               ? state_table_->FindState(
    750                   StateTuple(nt_prefix, nonterminal, nt_start))
    751               : kNoStateId;
    752           Label ilabel = (epsilon_on_replace_) ? 0 : arc.ilabel;
    753           *arcp = A(ilabel, 0, arc.weight, nt_nextstate);
    754         } else {
    755           return false;
    756         }
    757       } else {
    758         StateId nextstate = flags & kArcNextStateValue
    759             ? state_table_->FindState(
    760                 StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
    761             : kNoStateId;
    762         *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate);
    763       }
    764     }
    765     return true;
    766   }
    767 
    768   // Returns the arc iterator flags supported by this Fst.
    769   uint32 ArcIteratorFlags() const {
    770     uint32 flags = kArcValueFlags;
    771     if (!always_cache_)
    772       flags |= kArcNoCache;
    773     return flags;
    774   }
    775 
    776   T* GetStateTable() const {
    777     return state_table_;
    778   }
    779 
    780   const Fst<A>* GetFst(Label fst_id) const {
    781     return fst_array_[fst_id];
    782   }
    783 
    784   bool EpsilonOnReplace() const { return epsilon_on_replace_; }
    785 
  // private helper classes
 private:
  // Multiplier used by StackPrefixKey when hashing stack prefixes; the
  // value is supplied by the out-of-class definition of kPrime0.
  static const size_t kPrime0;
    789 
    790   // \class PrefixTuple
    791   // \brief Tuple of fst_id and destination state (entry in stack prefix)
    792   struct PrefixTuple {
    793     PrefixTuple(Label f, StateId s) : fst_id(f), nextstate(s) {}
    794 
    795     Label   fst_id;
    796     StateId nextstate;
    797   };
    798 
    799   // \class StackPrefix
    800   // \brief Container for stack prefix.
    801   class StackPrefix {
    802    public:
    803     StackPrefix() {}
    804 
    805     // copy constructor
    806     StackPrefix(const StackPrefix& x) :
    807         prefix_(x.prefix_) {
    808     }
    809 
    810     void Push(StateId fst_id, StateId nextstate) {
    811       prefix_.push_back(PrefixTuple(fst_id, nextstate));
    812     }
    813 
    814     void Pop() {
    815       prefix_.pop_back();
    816     }
    817 
    818     const PrefixTuple& Top() const {
    819       return prefix_[prefix_.size()-1];
    820     }
    821 
    822     size_t Depth() const {
    823       return prefix_.size();
    824     }
    825 
    826    public:
    827     vector<PrefixTuple> prefix_;
    828   };
    829 
    830 
    831   // \class StackPrefixEqual
    832   // \brief Compare two stack prefix classes for equality
    833   class StackPrefixEqual {
    834    public:
    835     bool operator()(const StackPrefix& x, const StackPrefix& y) const {
    836       if (x.prefix_.size() != y.prefix_.size()) return false;
    837       for (size_t i = 0; i < x.prefix_.size(); ++i) {
    838         if (x.prefix_[i].fst_id    != y.prefix_[i].fst_id ||
    839            x.prefix_[i].nextstate != y.prefix_[i].nextstate) return false;
    840       }
    841       return true;
    842     }
    843   };
    844 
    845   //
    846   // \class StackPrefixKey
    847   // \brief Hash function for stack prefix to prefix id
    848   class StackPrefixKey {
    849    public:
    850     size_t operator()(const StackPrefix& x) const {
    851       size_t sum = 0;
    852       for (size_t i = 0; i < x.prefix_.size(); ++i) {
    853         sum += x.prefix_[i].fst_id + x.prefix_[i].nextstate*kPrime0;
    854       }
    855       return sum;
    856     }
    857   };
    858 
  // Maps each distinct stack prefix to its prefix id (the prefix's index in
  // stackprefix_array_).
  typedef unordered_map<StackPrefix, PrefixId, StackPrefixKey, StackPrefixEqual>
  StackPrefixHash;
    861 
    862   // private methods
    863  private:
    864   // hash stack prefix (return unique index into stackprefix array)
    865   PrefixId GetPrefixId(const StackPrefix& prefix) {
    866     typename StackPrefixHash::iterator it = prefix_hash_.find(prefix);
    867     if (it == prefix_hash_.end()) {
    868       PrefixId prefix_id = stackprefix_array_.size();
    869       stackprefix_array_.push_back(prefix);
    870       prefix_hash_[prefix] = prefix_id;
    871       return prefix_id;
    872     } else {
    873       return it->second;
    874     }
    875   }
    876 
    877   // prefix id after a stack pop
    878   PrefixId PopPrefix(StackPrefix prefix) {
    879     prefix.Pop();
    880     return GetPrefixId(prefix);
    881   }
    882 
    883   // prefix id after a stack push
    884   PrefixId PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) {
    885     prefix.Push(fst_id, nextstate);
    886     return GetPrefixId(prefix);
    887   }
    888 
    889 
  // private data
 private:
  // runtime options
  // When true, arcs recursing into a non-terminal get input label 0.
  bool epsilon_on_replace_;
  bool always_cache_;  // Optionally caching arc iterator disabled when true

  // state table
  StateTable *state_table_;

  // cross index of unique stack prefix
  // could potentially have one copy of prefix array
  StackPrefixHash prefix_hash_;            // stack prefix -> prefix id
  vector<StackPrefix> stackprefix_array_;  // prefix id -> stack prefix

  // Labels that ReplaceFstMatcher registers as multi-epsilon labels.
  set<Label> nonterminal_set_;
  // Output label of a non-terminal arc -> index of its Fst in fst_array_.
  NonTerminalHash nonterminal_hash_;
  // Component Fsts indexed by fst id; entries may be null.
  vector<const Fst<A>*> fst_array_;
  // NOTE(review): presumably the root non-terminal; not read in this chunk.
  Label root_;

  void operator=(const ReplaceFstImpl<A, T> &);  // disallow
    910 };
    911 
    912 
// Multiplier used by StackPrefixKey when hashing stack prefixes.
template <class A, class T>
const size_t ReplaceFstImpl<A, T>::kPrime0 = 7853;
    915 
    916 //
    917 // \class ReplaceFst
    918 // \brief Recursivively replaces arcs in the root Fst with other Fsts.
    919 // This version is a delayed Fst.
    920 //
    921 // ReplaceFst supports dynamic replacement of arcs in one Fst with
    922 // another Fst. This replacement is recursive.  ReplaceFst can be used
    923 // to support a variety of delayed constructions such as recursive
    924 // transition networks, union, or closure.  It is constructed with an
    925 // array of Fst(s). One Fst represents the root (or topology)
    926 // machine. The root Fst refers to other Fsts by recursively replacing
    927 // arcs labeled as non-terminals with the matching non-terminal
    928 // Fst. Currently the ReplaceFst uses the output symbols of the arcs
    929 // to determine whether the arc is a non-terminal arc or not. A
    930 // non-terminal can be any label that is not a non-zero terminal label
    931 // in the output alphabet.
    932 //
    933 // Note that the constructor uses a vector of pair<>. These correspond
    934 // to the tuple of non-terminal Label and corresponding Fst. For example
    935 // to implement the closure operation we need 2 Fsts. The first root
    936 // Fst is a single Arc on the start State that self loops, it references
    937 // the particular machine for which we are performing the closure operation.
    938 //
    939 // The ReplaceFst class supports an optionally caching arc iterator:
    940 //    ArcIterator< ReplaceFst<A> >
    941 // The ReplaceFst need to be built such that it is known to be ilabel
    942 // or olabel sorted (see usage below).
    943 //
    944 // Observe that Matcher<Fst<A> > will use the optionally caching arc
    945 // iterator when available (Fst is ilabel sorted and matching on the
    946 // input, or Fst is olabel sorted and matching on the output).
    947 // In order to obtain the most efficient behaviour, it is recommended
    948 // to set 'epsilon_on_replace' to false (this means constructing acceptors
    949 // as transducers with epsilons on the input side of nonterminal arcs)
    950 // and matching on the input side.
    951 //
    952 // This class attaches interface to implementation and handles
    953 // reference counting, delegating most methods to ImplToFst.
    954 template <class A, class T = DefaultReplaceStateTable<A> >
    955 class ReplaceFst : public ImplToFst< ReplaceFstImpl<A, T> > {
    956  public:
    957   friend class ArcIterator< ReplaceFst<A, T> >;
    958   friend class StateIterator< ReplaceFst<A, T> >;
    959   friend class ReplaceFstMatcher<A, T>;
    960 
    961   typedef A Arc;
    962   typedef typename A::Label   Label;
    963   typedef typename A::Weight  Weight;
    964   typedef typename A::StateId StateId;
    965   typedef CacheState<A> State;
    966   typedef ReplaceFstImpl<A, T> Impl;
    967 
    968   using ImplToFst<Impl>::Properties;
    969 
    970   ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
    971              Label root)
    972       : ImplToFst<Impl>(new Impl(fst_array, ReplaceFstOptions<A, T>(root))) {}
    973 
    974   ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
    975              const ReplaceFstOptions<A, T> &opts)
    976       : ImplToFst<Impl>(new Impl(fst_array, opts)) {}
    977 
    978   // See Fst<>::Copy() for doc.
    979   ReplaceFst(const ReplaceFst<A, T>& fst, bool safe = false)
    980       : ImplToFst<Impl>(fst, safe) {}
    981 
    982   // Get a copy of this ReplaceFst. See Fst<>::Copy() for further doc.
    983   virtual ReplaceFst<A, T> *Copy(bool safe = false) const {
    984     return new ReplaceFst<A, T>(*this, safe);
    985   }
    986 
    987   virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
    988 
    989   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    990     GetImpl()->InitArcIterator(s, data);
    991   }
    992 
    993   virtual MatcherBase<A> *InitMatcher(MatchType match_type) const {
    994     if ((GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
    995         ((match_type == MATCH_INPUT && Properties(kILabelSorted, false)) ||
    996          (match_type == MATCH_OUTPUT && Properties(kOLabelSorted, false)))) {
    997       return new ReplaceFstMatcher<A, T>(*this, match_type);
    998     }
    999     else {
   1000       VLOG(2) << "Not using replace matcher";
   1001       return 0;
   1002     }
   1003   }
   1004 
   1005   bool CyclicDependencies() const {
   1006     return GetImpl()->CyclicDependencies();
   1007   }
   1008 
   1009  private:
   1010   // Makes visible to friends.
   1011   Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
   1012 
   1013   void operator=(const ReplaceFst<A> &fst);  // disallow
   1014 };
   1015 
   1016 
   1017 // Specialization for ReplaceFst.
   1018 template<class A, class T>
   1019 class StateIterator< ReplaceFst<A, T> >
   1020     : public CacheStateIterator< ReplaceFst<A, T> > {
   1021  public:
   1022   explicit StateIterator(const ReplaceFst<A, T> &fst)
   1023       : CacheStateIterator< ReplaceFst<A, T> >(fst, fst.GetImpl()) {}
   1024 
   1025  private:
   1026   DISALLOW_COPY_AND_ASSIGN(StateIterator);
   1027 };
   1028 
   1029 
// Specialization for ReplaceFst.
// Implements optional caching. It can be used as follows:
//
//   ReplaceFst<A> replace;
//   ArcIterator< ReplaceFst<A> > aiter(replace, s);
//   // Note: ArcIterator< Fst<A> > is always a caching arc iterator.
//   aiter.SetFlags(kArcNoCache, kArcNoCache);
//   // Use the arc iterator, no arc will be cached, no state will be expanded.
//   // The varied 'kArcValueFlags' can be used to decide which part
//   // of arc values needs to be computed.
//   aiter.SetFlags(kArcILabelValue, kArcValueFlags);
//   // Only want the ilabel for this arc
//   aiter.Value();  // Does not compute the destination state.
//   aiter.Next();
//   aiter.SetFlags(kArcNextStateValue, kArcNextStateValue);
//   // Want both ilabel and nextstate for that arc
//   aiter.Value();  // Does compute the destination state and inserts it
//                   // in the replace state table.
//   // No Arc has been cached at that point.
//
// Position layout in the non-caching set-up: if the component state needs
// an exiting 'final' arc, that arc is reported at position 0 and the
// component's own arcs follow, shifted by 'offset_' (1 in that case,
// 0 otherwise). Once the state is cached, 'offset_' is 0.
//
template <class A, class T>
class ArcIterator< ReplaceFst<A, T> > {
 public:
  typedef A Arc;
  typedef typename A::StateId StateId;

  // Positions the iterator before the first arc of state 's'. If the Fst
  // does not support kArcNoCache, the state is expanded immediately. A
  // cached state uses the cached arcs directly. Otherwise, the iterator reads
  // the component Fst's arcs and the caching decision is deferred.
  ArcIterator(const ReplaceFst<A, T> &fst, StateId s)
      : fst_(fst), state_(s), pos_(0), offset_(0), flags_(0), arcs_(0),
        data_flags_(0), final_flags_(0) {
    // Zero ref counts mark the data structures as not (yet) holding a
    // reference; the destructor only decrements non-null counters.
    cache_data_.ref_count = 0;
    local_data_.ref_count = 0;

    // If FST does not support optional caching, force caching.
    if(!(fst_.GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
       !(fst_.GetImpl()->HasArcs(state_)))
       fst_.GetImpl()->Expand(state_);

    // If state is already cached, use cached arcs array.
    if (fst_.GetImpl()->HasArcs(state_)) {
      (fst_.GetImpl())->template CacheImpl<A>::InitArcIterator(state_,
                                                               &cache_data_);
      num_arcs_ = cache_data_.narcs;
      arcs_ = cache_data_.arcs;      // 'arcs_' is a ptr to the cached arcs.
      data_flags_ = kArcValueFlags;  // All the arc member values are valid.
    } else {  // Otherwise delay decision until Value() is called.
      tuple_ = fst_.GetImpl()->GetStateTable()->Tuple(state_);
      if (tuple_.fst_state == kNoStateId) {
        // Empty component state: no arcs at all.
        num_arcs_ = 0;
      } else {
        // The decision to cache or not to cache has been deferred
        // until Value() or SetFlags() is called. However, the arc
        // iterator is set up now to be ready for non-caching in order
        // to keep the Value() method simple and efficient.
        const Fst<A>* fst = fst_.GetImpl()->GetFst(tuple_.fst_id);
        fst->InitArcIterator(tuple_.fst_state, &local_data_);
        // 'arcs_' is a pointer to the arcs in the underlying machine.
        arcs_ = local_data_.arcs;
        // Compute the final arc (but not its destination state)
        // if a final arc is required.
        bool has_final_arc = fst_.GetImpl()->ComputeFinalArc(
            tuple_,
            &final_arc_,
            kArcValueFlags & ~kArcNextStateValue);
        // Set the arc value flags that hold for 'final_arc_'.
        final_flags_ = kArcValueFlags & ~kArcNextStateValue;
        // Compute the number of arcs.
        num_arcs_ = local_data_.narcs;
        if (has_final_arc)
          ++num_arcs_;
        // Set the offset between the underlying arc positions and
        // the positions in the arc iterator.
        offset_ = num_arcs_ - local_data_.narcs;
        // Defers the decision to cache or not until Value() or
        // SetFlags() is called.
        data_flags_ = 0;
      }
    }
  }

  // Releases the references taken on cached and/or component arc data.
  ~ArcIterator() {
    if (cache_data_.ref_count)
      --(*cache_data_.ref_count);
    if (local_data_.ref_count)
      --(*local_data_.ref_count);
  }

  // Expands and caches state_ through ReplaceFst::InitArcIterator() and
  // switches 'arcs_' to the cached arcs, for which every arc value is valid.
  void ExpandAndCache() const   {
    // TODO(allauzen): revisit this
    // fst_.GetImpl()->Expand(state_, tuple_, local_data_);
    // (fst_.GetImpl())->CacheImpl<A>*>::InitArcIterator(state_,
    //                                               &cache_data_);
    //
    fst_.InitArcIterator(state_, &cache_data_);  // Expand and cache state.
    arcs_ = cache_data_.arcs;  // 'arcs_' is a pointer to the cached arcs.
    data_flags_ = kArcValueFlags;  // All the arc member values are valid.
    offset_ = 0;  // No offset

  }

  // Sets up 'arcs_' according to flags_. With kArcNoCache it points 'arcs_'
  // at the component Fst's raw arcs. Their weights are always valid; their
  // ilabels are valid unless epsilon_on_replace is set. Otherwise the state
  // is expanded and cached.
  // NOTE(review): when the tuple has no component state, the constructor
  // never filled local_data_, yet the kArcNoCache branch still reads
  // local_data_.arcs/narcs; harmless only because Done() is already true
  // (num_arcs_ == 0) — confirm.
  void Init() {
    if (flags_ & kArcNoCache) {  // If caching is disabled
      // 'arcs_' is a pointer to the arcs in the underlying machine.
      arcs_ = local_data_.arcs;
      // Set the arcs value flags that hold for 'arcs_'.
      data_flags_ = kArcWeightValue;
      if (!fst_.GetImpl()->EpsilonOnReplace())
          data_flags_ |= kArcILabelValue;
      // Set the offset between the underlying arc positions and
      // the positions in the arc iterator.
      offset_ = num_arcs_ - local_data_.narcs;
    } else {  // Otherwise, expand and cache
      ExpandAndCache();
    }
  }

  bool Done() const { return pos_ >= num_arcs_; }

  // Returns the arc at the current position. The arc is computed on the fly
  // when the stored data does not carry all the values requested by flags_.
  const A& Value() const {
    // If 'data_flags_' was set to 0, non-caching was not requested
    if (!data_flags_) {
      // TODO(allauzen): revisit this.
      if (flags_ & kArcNoCache) {
        // Should never happen.
        FSTERROR() << "ReplaceFst: inconsistent arc iterator flags";
      }
      ExpandAndCache();  // Expand and cache.
    }

    if (pos_ - offset_ >= 0) {  // The requested arc is not the 'final' arc.
      const A& arc = arcs_[pos_ - offset_];
      if ((data_flags_ & flags_) == (flags_ & kArcValueFlags)) {
        // If the value flags for 'arc' match the required value flags
        // then return 'arc'.
        return arc;
      } else {
        // Otherwise, compute the corresponding arc on-the-fly.
        fst_.GetImpl()->ComputeArc(tuple_, arc, &arc_, flags_ & kArcValueFlags);
        return arc_;
      }
    } else {  // The requested arc is the 'final' arc.
      if ((final_flags_ & flags_) != (flags_ & kArcValueFlags)) {
        // If the arc value flags that hold for the final arc
        // do not match the requested value flags, then
        // 'final_arc_' needs to be updated.
        fst_.GetImpl()->ComputeFinalArc(tuple_, &final_arc_,
                                    flags_ & kArcValueFlags);
        final_flags_ = flags_ & kArcValueFlags;
      }
      return final_arc_;
    }
  }

  void Next() { ++pos_; }

  size_t Position() const { return pos_; }

  void Reset() { pos_ = 0;  }

  void Seek(size_t pos) { pos_ = pos; }

  uint32 Flags() const { return flags_; }

  // Updates the behavioral flags, keeping only those the Fst supports.
  // NOTE(review): bits of 'f' outside 'mask' are OR-ed in as well (f is not
  // masked); callers are expected to pass f within mask.
  void SetFlags(uint32 f, uint32 mask) {
    // Update the flags taking into account what flags are supported
    // by the Fst.
    flags_ &= ~mask;
    flags_ |= (f & fst_.GetImpl()->ArcIteratorFlags());
    // If non-caching is not requested (and caching has not already
    // been performed), then flush 'data_flags_' to request caching
    // during the next call to Value().
    if (!(flags_ & kArcNoCache) && data_flags_ != kArcValueFlags) {
      if (!fst_.GetImpl()->HasArcs(state_))
         data_flags_ = 0;
    }
    // If 'data_flags_' has been flushed but non-caching is requested
    // before calling Value(), then set up the iterator for non-caching.
    if ((f & kArcNoCache) && (!data_flags_))
      Init();
  }

 private:
  const ReplaceFst<A, T> &fst_;           // Reference to the FST
  StateId state_;                         // State in the FST
  mutable typename T::StateTuple tuple_;  // Tuple corresponding to state_

  ssize_t pos_;             // Current position
  mutable ssize_t offset_;  // Offset between position in iterator and in arcs_
  ssize_t num_arcs_;        // Number of arcs at state_
  uint32 flags_;            // Behavioral flags for the arc iterator
  mutable Arc arc_;         // Memory to temporarily store computed arcs

  mutable ArcIteratorData<Arc> cache_data_;  // Arc iterator data in cache
  mutable ArcIteratorData<Arc> local_data_;  // Arc iterator data in local fst

  mutable const A* arcs_;       // Array of arcs
  mutable uint32 data_flags_;   // Arc value flags valid for data in arcs_
  mutable Arc final_arc_;       // Final arc (when required)
  mutable uint32 final_flags_;  // Arc value flags valid for final_arc_

  DISALLOW_COPY_AND_ASSIGN(ArcIterator);
};
   1231 
   1232 
   1233 template <class A, class T>
   1234 class ReplaceFstMatcher : public MatcherBase<A> {
   1235  public:
   1236   typedef A Arc;
   1237   typedef typename A::StateId StateId;
   1238   typedef typename A::Label Label;
   1239   typedef MultiEpsMatcher<Matcher<Fst<A> > > LocalMatcher;
   1240 
   1241   ReplaceFstMatcher(const ReplaceFst<A, T> &fst, fst::MatchType match_type)
   1242       : fst_(fst),
   1243         impl_(fst_.GetImpl()),
   1244         s_(fst::kNoStateId),
   1245         match_type_(match_type),
   1246         current_loop_(false),
   1247         final_arc_(false),
   1248         loop_(fst::kNoLabel, 0, A::Weight::One(), fst::kNoStateId) {
   1249     if (match_type_ == fst::MATCH_OUTPUT)
   1250       swap(loop_.ilabel, loop_.olabel);
   1251     InitMatchers();
   1252   }
   1253 
   1254   ReplaceFstMatcher(const ReplaceFstMatcher<A, T> &matcher, bool safe = false)
   1255       : fst_(matcher.fst_),
   1256         impl_(fst_.GetImpl()),
   1257         s_(fst::kNoStateId),
   1258         match_type_(matcher.match_type_),
   1259         current_loop_(false),
   1260         loop_(fst::kNoLabel, 0, A::Weight::One(), fst::kNoStateId) {
   1261     if (match_type_ == fst::MATCH_OUTPUT)
   1262       swap(loop_.ilabel, loop_.olabel);
   1263     InitMatchers();
   1264   }
   1265 
   1266   // Create a local matcher for each component Fst of replace.
   1267   // LocalMatcher is a multi epsilon wrapper matcher. MultiEpsilonMatcher
   1268   // is used to match each non-terminal arc, since these non-terminal
   1269   // turn into epsilons on recursion.
   1270   void InitMatchers() {
   1271     const vector<const Fst<A>*>& fst_array = impl_->fst_array_;
   1272     matcher_.resize(fst_array.size(), 0);
   1273     for (size_t i = 0; i < fst_array.size(); ++i) {
   1274       if (fst_array[i]) {
   1275         matcher_[i] =
   1276             new LocalMatcher(*fst_array[i], match_type_, kMultiEpsList);
   1277 
   1278         typename set<Label>::iterator it = impl_->nonterminal_set_.begin();
   1279         for (; it != impl_->nonterminal_set_.end(); ++it) {
   1280           matcher_[i]->AddMultiEpsLabel(*it);
   1281         }
   1282       }
   1283     }
   1284   }
   1285 
   1286   virtual ReplaceFstMatcher<A, T> *Copy(bool safe = false) const {
   1287     return new ReplaceFstMatcher<A, T>(*this, safe);
   1288   }
   1289 
   1290   virtual ~ReplaceFstMatcher() {
   1291     for (size_t i = 0; i < matcher_.size(); ++i)
   1292       delete matcher_[i];
   1293   }
   1294 
   1295   virtual MatchType Type(bool test) const {
   1296     if (match_type_ == MATCH_NONE)
   1297       return match_type_;
   1298 
   1299     uint64 true_prop =  match_type_ == MATCH_INPUT ?
   1300         kILabelSorted : kOLabelSorted;
   1301     uint64 false_prop = match_type_ == MATCH_INPUT ?
   1302         kNotILabelSorted : kNotOLabelSorted;
   1303     uint64 props = fst_.Properties(true_prop | false_prop, test);
   1304 
   1305     if (props & true_prop)
   1306       return match_type_;
   1307     else if (props & false_prop)
   1308       return MATCH_NONE;
   1309     else
   1310       return MATCH_UNKNOWN;
   1311   }
   1312 
   1313   virtual const Fst<A> &GetFst() const {
   1314     return fst_;
   1315   }
   1316 
   1317   virtual uint64 Properties(uint64 props) const {
   1318     return props;
   1319   }
   1320 
   1321  private:
   1322   // Set the sate from which our matching happens.
   1323   virtual void SetState_(StateId s) {
   1324     if (s_ == s) return;
   1325 
   1326     s_ = s;
   1327     tuple_ = impl_->GetStateTable()->Tuple(s_);
   1328     if (tuple_.fst_state == kNoStateId) {
   1329       done_ = true;
   1330       return;
   1331     }
   1332     // Get current matcher. Used for non epsilon matching
   1333     current_matcher_ = matcher_[tuple_.fst_id];
   1334     current_matcher_->SetState(tuple_.fst_state);
   1335     loop_.nextstate = s_;
   1336 
   1337     final_arc_ = false;
   1338   }
   1339 
   1340   // Search for label, from previous set state. If label == 0, first
   1341   // hallucinate and epsilon loop, else use the underlying matcher to
   1342   // search for the label or epsilons.
   1343   // - Note since the ReplaceFST recursion on non-terminal arcs causes
   1344   //   epsilon transitions to be created we use the MultiEpsilonMatcher
   1345   //   to search for possible matches of non terminals.
   1346   // - If the component Fst reaches a final state we also need to add
   1347   //   the exiting final arc.
   1348   virtual bool Find_(Label label) {
   1349     bool found = false;
   1350     label_ = label;
   1351     if (label_ == 0 || label_ == kNoLabel) {
   1352       // Compute loop directly, saving Replace::ComputeArc
   1353       if (label_ == 0) {
   1354         current_loop_ = true;
   1355         found = true;
   1356       }
   1357       // Search for matching multi epsilons
   1358       final_arc_ = impl_->ComputeFinalArc(tuple_, 0);
   1359       found = current_matcher_->Find(kNoLabel) || final_arc_ || found;
   1360     } else {
   1361       // Search on sub machine directly using sub machine matcher.
   1362       found = current_matcher_->Find(label_);
   1363     }
   1364     return found;
   1365   }
   1366 
   1367   virtual bool Done_() const {
   1368     return !current_loop_ && !final_arc_ && current_matcher_->Done();
   1369   }
   1370 
   1371   virtual const Arc& Value_() const {
   1372     if (current_loop_) {
   1373       return loop_;
   1374     }
   1375     if (final_arc_) {
   1376       impl_->ComputeFinalArc(tuple_, &arc_);
   1377       return arc_;
   1378     }
   1379     const Arc& component_arc = current_matcher_->Value();
   1380     impl_->ComputeArc(tuple_, component_arc, &arc_);
   1381     return arc_;
   1382   }
   1383 
   1384   virtual void Next_() {
   1385     if (current_loop_) {
   1386       current_loop_ = false;
   1387       return;
   1388     }
   1389     if (final_arc_) {
   1390       final_arc_ = false;
   1391       return;
   1392     }
   1393     current_matcher_->Next();
   1394   }
   1395 
   1396   const ReplaceFst<A, T>& fst_;
   1397   ReplaceFstImpl<A, T> *impl_;
   1398   LocalMatcher* current_matcher_;
   1399   vector<LocalMatcher*> matcher_;
   1400 
   1401   StateId s_;                        // Current state
   1402   Label label_;                      // Current label
   1403 
   1404   MatchType match_type_;             // Supplied by caller
   1405   mutable bool done_;
   1406   mutable bool current_loop_;        // Current arc is the implicit loop
   1407   mutable bool final_arc_;           // Current arc for exiting recursion
   1408   mutable typename T::StateTuple tuple_;  // Tuple corresponding to state_
   1409   mutable Arc arc_;
   1410   Arc loop_;
   1411 };
   1412 
   1413 template <class A, class T> inline
   1414 void ReplaceFst<A, T>::InitStateIterator(StateIteratorData<A> *data) const {
   1415   data->base = new StateIterator< ReplaceFst<A, T> >(*this);
   1416 }
   1417 
// ReplaceFst over StdArc with the default replace state table.
typedef ReplaceFst<StdArc> StdReplaceFst;
   1419 
   1420 
   1421 // // Recursivively replaces arcs in the root Fst with other Fsts.
   1422 // This version writes the result of replacement to an output MutableFst.
   1423 //
   1424 // Replace supports replacement of arcs in one Fst with another
   1425 // Fst. This replacement is recursive.  Replace takes an array of
   1426 // Fst(s). One Fst represents the root (or topology) machine. The root
   1427 // Fst refers to other Fsts by recursively replacing arcs labeled as
   1428 // non-terminals with the matching non-terminal Fst. Currently Replace
   1429 // uses the output symbols of the arcs to determine whether the arc is
   1430 // a non-terminal arc or not. A non-terminal can be any label that is
   1431 // not a non-zero terminal label in the output alphabet.  Note that
   1432 // input argument is a vector of pair<>. These correspond to the tuple
   1433 // of non-terminal Label and corresponding Fst.
   1434 template<class Arc>
   1435 void Replace(const vector<pair<typename Arc::Label,
   1436              const Fst<Arc>* > >& ifst_array,
   1437              MutableFst<Arc> *ofst, typename Arc::Label root,
   1438              bool epsilon_on_replace) {
   1439   ReplaceFstOptions<Arc> opts(root, epsilon_on_replace);
   1440   opts.gc_limit = 0;  // Cache only the last state for fastest copy.
   1441   *ofst = ReplaceFst<Arc>(ifst_array, opts);
   1442 }
   1443 
   1444 template<class Arc>
   1445 void Replace(const vector<pair<typename Arc::Label,
   1446              const Fst<Arc>* > >& ifst_array,
   1447              MutableFst<Arc> *ofst, typename Arc::Label root) {
   1448   Replace(ifst_array, ofst, root, false);
   1449 }
   1450 
   1451 }  // namespace fst
   1452 
   1453 #endif  // FST_LIB_REPLACE_H__
   1454