Home | History | Annotate | Download | only in fst
      1 // randgen.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // \file
     19 // Classes and functions to generate random paths through an FST.
     20 
     21 #ifndef FST_LIB_RANDGEN_H__
     22 #define FST_LIB_RANDGEN_H__
     23 
     24 #include <cmath>
     25 #include <cstdlib>
     26 #include <ctime>
     27 #include <map>
     28 
     29 #include <fst/accumulator.h>
     30 #include <fst/cache.h>
     31 #include <fst/dfs-visit.h>
     32 #include <fst/mutable-fst.h>
     33 
     34 namespace fst {
     35 
     36 //
     37 // ARC SELECTORS - these function objects are used to select a random
     38 // transition to take from an FST's state. They should return a number
     39 // N s.t. 0 <= N <= NumArcs(). If N < NumArcs(), then the N-th
     40 // transition is selected. If N == NumArcs(), then the final weight at
     41 // that state is selected (i.e., the 'super-final' transition is selected).
     42 // It can be assumed these will not be called unless either there
     43 // are transitions leaving the state and/or the state is final.
     44 //
     45 
     46 // Randomly selects a transition using the uniform distribution.
     47 template <class A>
     48 struct UniformArcSelector {
     49   typedef typename A::StateId StateId;
     50   typedef typename A::Weight Weight;
     51 
     52   UniformArcSelector(int seed = time(0)) { srand(seed); }
     53 
     54   size_t operator()(const Fst<A> &fst, StateId s) const {
     55     double r = rand()/(RAND_MAX + 1.0);
     56     size_t n = fst.NumArcs(s);
     57     if (fst.Final(s) != Weight::Zero())
     58       ++n;
     59     return static_cast<size_t>(r * n);
     60   }
     61 };
     62 
     63 
// Randomly selects a transition w.r.t. the weights treated as negative
// log probabilities after normalizing for the total weight leaving
// the state. Weight::Zero() transitions are disregarded.
// Assumes Weight::Value() accesses the floating point
// representation of the weight.
     69 template <class A>
     70 class LogProbArcSelector {
     71  public:
     72   typedef typename A::StateId StateId;
     73   typedef typename A::Weight Weight;
     74 
     75   LogProbArcSelector(int seed = time(0)) { srand(seed); }
     76 
     77   size_t operator()(const Fst<A> &fst, StateId s) const {
     78     // Find total weight leaving state
     79     double sum = 0.0;
     80     for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done();
     81          aiter.Next()) {
     82       const A &arc = aiter.Value();
     83       sum += exp(-to_log_weight_(arc.weight).Value());
     84     }
     85     sum += exp(-to_log_weight_(fst.Final(s)).Value());
     86 
     87     double r = rand()/(RAND_MAX + 1.0);
     88     double p = 0.0;
     89     int n = 0;
     90     for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done();
     91          aiter.Next(), ++n) {
     92       const A &arc = aiter.Value();
     93       p += exp(-to_log_weight_(arc.weight).Value());
     94       if (p > r * sum) return n;
     95     }
     96     return n;
     97   }
     98 
     99  private:
    100   WeightConvert<Weight, Log64Weight> to_log_weight_;
    101 };
    102 
// Convenience definitions: LogProbArcSelector instantiated for the
// StdArc and LogArc arc types.
typedef LogProbArcSelector<StdArc> StdArcSelector;
typedef LogProbArcSelector<LogArc> LogArcSelector;
    106 
    107 
// Same as LogProbArcSelector but use CacheLogAccumulator to cache
// the cumulative weight computations.
    110 template <class A>
    111 class FastLogProbArcSelector : public LogProbArcSelector<A> {
    112  public:
    113   typedef typename A::StateId StateId;
    114   typedef typename A::Weight Weight;
    115   using LogProbArcSelector<A>::operator();
    116 
    117   FastLogProbArcSelector(int seed = time(0))
    118       : LogProbArcSelector<A>(seed),
    119         seed_(seed) {}
    120 
    121   size_t operator()(const Fst<A> &fst, StateId s,
    122                     CacheLogAccumulator<A> *accumulator) const {
    123     accumulator->SetState(s);
    124     ArcIterator< Fst<A> > aiter(fst, s);
    125     // Find total weight leaving state
    126     double sum = to_log_weight_(accumulator->Sum(fst.Final(s), &aiter, 0,
    127                                                  fst.NumArcs(s))).Value();
    128     double r = -log(rand()/(RAND_MAX + 1.0));
    129     return accumulator->LowerBound(r + sum, &aiter);
    130   }
    131 
    132   int Seed() const { return seed_; }
    133  private:
    134   int seed_;
    135   WeightConvert<Weight, Log64Weight> to_log_weight_;
    136 };
    137 
    138 // Random path state info maintained by RandGenFst and passed to samplers.
    139 template <typename A>
    140 struct RandState {
    141   typedef typename A::StateId StateId;
    142 
    143   StateId state_id;              // current input FST state
    144   size_t nsamples;               // # of samples to be sampled at this state
    145   size_t length;                 // length of path to this random state
    146   size_t select;                 // previous sample arc selection
    147   const RandState<A> *parent;    // previous random state on this path
    148 
    149   RandState(StateId s, size_t n, size_t l, size_t k, const RandState<A> *p)
    150       : state_id(s), nsamples(n), length(l), select(k), parent(p) {}
    151 
    152   RandState()
    153       : state_id(kNoStateId), nsamples(0), length(0), select(0), parent(0) {}
    154 };
    155 
// This class, given an arc selector, samples, with replacement,
// multiple random transitions from an FST's state. This is a generic
// version with a straightforward use of the arc selector.
// Specializations may be defined for arc selectors for greater
// efficiency or special behavior.
    161 template <class A, class S>
    162 class ArcSampler {
    163  public:
    164   typedef typename A::StateId StateId;
    165   typedef typename A::Weight Weight;
    166 
    167   // The 'max_length' may be interpreted (including ignored) by a
    168   // sampler as it chooses. This generic version interprets this literally.
    169   ArcSampler(const Fst<A> &fst, const S &arc_selector,
    170              int max_length = INT_MAX)
    171       : fst_(fst),
    172         arc_selector_(arc_selector),
    173         max_length_(max_length) {}
    174 
    175   // Allow updating Fst argument; pass only if changed.
    176   ArcSampler(const ArcSampler<A, S> &sampler, const Fst<A> *fst = 0)
    177       : fst_(fst ? *fst : sampler.fst_),
    178         arc_selector_(sampler.arc_selector_),
    179         max_length_(sampler.max_length_) {
    180     Reset();
    181   }
    182 
    183   // Samples 'rstate.nsamples' from state 'state_id'. The 'rstate.length' is
    184   // the length of the path to 'rstate'. Returns true if samples were
    185   // collected.  No samples may be collected if either there are no (including
    186   // 'super-final') transitions leaving that state or if the
    187   // 'max_length' has been deemed reached. Use the iterator members to
    188   // read the samples. The samples will be in their original order.
    189   bool Sample(const RandState<A> &rstate) {
    190     sample_map_.clear();
    191     if ((fst_.NumArcs(rstate.state_id) == 0 &&
    192          fst_.Final(rstate.state_id) == Weight::Zero()) ||
    193         rstate.length == max_length_) {
    194       Reset();
    195       return false;
    196     }
    197 
    198     for (size_t i = 0; i < rstate.nsamples; ++i)
    199       ++sample_map_[arc_selector_(fst_, rstate.state_id)];
    200     Reset();
    201     return true;
    202   }
    203 
    204   // More samples?
    205   bool Done() const { return sample_iter_ == sample_map_.end(); }
    206 
    207   // Gets the next sample.
    208   void Next() { ++sample_iter_; }
    209 
    210   // Returns a pair (N, K) where 0 <= N <= NumArcs(s) and 0 < K <= nsamples.
    211   // If N < NumArcs(s), then the N-th transition is specified.
    212   // If N == NumArcs(s), then the final weight at that state is
    213   // specified (i.e., the 'super-final' transition is specified).
    214   // For the specified transition, K repetitions have been sampled.
    215   pair<size_t, size_t> Value() const { return *sample_iter_; }
    216 
    217   void Reset() { sample_iter_ = sample_map_.begin(); }
    218 
    219   bool Error() const { return false; }
    220 
    221  private:
    222   const Fst<A> &fst_;
    223   const S &arc_selector_;
    224   int max_length_;
    225 
    226   // Stores (N, K) as described for Value().
    227   map<size_t, size_t> sample_map_;
    228   map<size_t, size_t>::const_iterator sample_iter_;
    229 
    230   // disallow
    231   ArcSampler<A, S> & operator=(const ArcSampler<A, S> &s);
    232 };
    233 
    234 
    235 // Specialization for FastLogProbArcSelector.
    236 template <class A>
    237 class ArcSampler<A, FastLogProbArcSelector<A> > {
    238  public:
    239   typedef FastLogProbArcSelector<A> S;
    240   typedef typename A::StateId StateId;
    241   typedef typename A::Weight Weight;
    242   typedef CacheLogAccumulator<A> C;
    243 
    244   ArcSampler(const Fst<A> &fst, const S &arc_selector, int max_length = INT_MAX)
    245       : fst_(fst),
    246         arc_selector_(arc_selector),
    247         max_length_(max_length),
    248         accumulator_(new C()) {
    249     accumulator_->Init(fst);
    250   }
    251 
    252   ArcSampler(const ArcSampler<A, S> &sampler, const Fst<A> *fst = 0)
    253       : fst_(fst ? *fst : sampler.fst_),
    254         arc_selector_(sampler.arc_selector_),
    255         max_length_(sampler.max_length_) {
    256     if (fst) {
    257       accumulator_ = new C();
    258       accumulator_->Init(*fst);
    259     } else {  // shallow copy
    260       accumulator_ = new C(*sampler.accumulator_);
    261     }
    262   }
    263 
    264   ~ArcSampler() {
    265     delete accumulator_;
    266   }
    267 
    268   bool Sample(const RandState<A> &rstate) {
    269     sample_map_.clear();
    270     if ((fst_.NumArcs(rstate.state_id) == 0 &&
    271          fst_.Final(rstate.state_id) == Weight::Zero()) ||
    272         rstate.length == max_length_) {
    273       Reset();
    274       return false;
    275     }
    276 
    277     for (size_t i = 0; i < rstate.nsamples; ++i)
    278       ++sample_map_[arc_selector_(fst_, rstate.state_id, accumulator_)];
    279     Reset();
    280     return true;
    281   }
    282 
    283   bool Done() const { return sample_iter_ == sample_map_.end(); }
    284   void Next() { ++sample_iter_; }
    285   pair<size_t, size_t> Value() const { return *sample_iter_; }
    286   void Reset() { sample_iter_ = sample_map_.begin(); }
    287 
    288   bool Error() const { return accumulator_->Error(); }
    289 
    290  private:
    291   const Fst<A> &fst_;
    292   const S &arc_selector_;
    293   int max_length_;
    294 
    295   // Stores (N, K) as described for Value().
    296   map<size_t, size_t> sample_map_;
    297   map<size_t, size_t>::const_iterator sample_iter_;
    298   C *accumulator_;
    299 
    300   // disallow
    301   ArcSampler<A, S> & operator=(const ArcSampler<A, S> &s);
    302 };
    303 
    304 
    305 // Options for random path generation with RandGenFst. The template argument
    306 // is an arc sampler, typically class 'ArcSampler' above.  Ownership of
    307 // the sampler is taken by RandGenFst.
    308 template <class S>
    309 struct RandGenFstOptions : public CacheOptions {
    310   S *arc_sampler;            // How to sample transitions at a state
    311   size_t npath;              // # of paths to generate
    312   bool weighted;             // Output tree weighted by path count; o.w.
    313                              // output unweighted DAG
    314   bool remove_total_weight;  // Remove total weight when output is weighted.
    315 
    316   RandGenFstOptions(const CacheOptions &copts, S *samp,
    317                     size_t n = 1, bool w = true, bool rw = false)
    318       : CacheOptions(copts),
    319         arc_sampler(samp),
    320         npath(n),
    321         weighted(w),
    322         remove_total_weight(rw) {}
    323 };
    324 
    325 
    326 // Implementation of RandGenFst.
    327 template <class A, class B, class S>
    328 class RandGenFstImpl : public CacheImpl<B> {
    329  public:
    330   using FstImpl<B>::SetType;
    331   using FstImpl<B>::SetProperties;
    332   using FstImpl<B>::SetInputSymbols;
    333   using FstImpl<B>::SetOutputSymbols;
    334 
    335   using CacheBaseImpl< CacheState<B> >::AddArc;
    336   using CacheBaseImpl< CacheState<B> >::HasArcs;
    337   using CacheBaseImpl< CacheState<B> >::HasFinal;
    338   using CacheBaseImpl< CacheState<B> >::HasStart;
    339   using CacheBaseImpl< CacheState<B> >::SetArcs;
    340   using CacheBaseImpl< CacheState<B> >::SetFinal;
    341   using CacheBaseImpl< CacheState<B> >::SetStart;
    342 
    343   typedef B Arc;
    344   typedef typename A::Label Label;
    345   typedef typename A::Weight Weight;
    346   typedef typename A::StateId StateId;
    347 
    348   RandGenFstImpl(const Fst<A> &fst, const RandGenFstOptions<S> &opts)
    349       : CacheImpl<B>(opts),
    350         fst_(fst.Copy()),
    351         arc_sampler_(opts.arc_sampler),
    352         npath_(opts.npath),
    353         weighted_(opts.weighted),
    354         remove_total_weight_(opts.remove_total_weight),
    355         superfinal_(kNoLabel) {
    356     SetType("randgen");
    357 
    358     uint64 props = fst.Properties(kFstProperties, false);
    359     SetProperties(RandGenProperties(props, weighted_), kCopyProperties);
    360 
    361     SetInputSymbols(fst.InputSymbols());
    362     SetOutputSymbols(fst.OutputSymbols());
    363   }
    364 
    365   RandGenFstImpl(const RandGenFstImpl &impl)
    366     : CacheImpl<B>(impl),
    367       fst_(impl.fst_->Copy(true)),
    368       arc_sampler_(new S(*impl.arc_sampler_, fst_)),
    369       npath_(impl.npath_),
    370       weighted_(impl.weighted_),
    371       superfinal_(kNoLabel) {
    372     SetType("randgen");
    373     SetProperties(impl.Properties(), kCopyProperties);
    374     SetInputSymbols(impl.InputSymbols());
    375     SetOutputSymbols(impl.OutputSymbols());
    376   }
    377 
    378   ~RandGenFstImpl() {
    379     for (int i = 0; i < state_table_.size(); ++i)
    380       delete state_table_[i];
    381     delete fst_;
    382     delete arc_sampler_;
    383   }
    384 
    385   StateId Start() {
    386     if (!HasStart()) {
    387       StateId s = fst_->Start();
    388       if (s == kNoStateId)
    389         return kNoStateId;
    390       StateId start = state_table_.size();
    391       SetStart(start);
    392       RandState<A> *rstate = new RandState<A>(s, npath_, 0, 0, 0);
    393       state_table_.push_back(rstate);
    394     }
    395     return CacheImpl<B>::Start();
    396   }
    397 
    398   Weight Final(StateId s) {
    399     if (!HasFinal(s)) {
    400       Expand(s);
    401     }
    402     return CacheImpl<B>::Final(s);
    403   }
    404 
    405   size_t NumArcs(StateId s) {
    406     if (!HasArcs(s)) {
    407       Expand(s);
    408     }
    409     return CacheImpl<B>::NumArcs(s);
    410   }
    411 
    412   size_t NumInputEpsilons(StateId s) {
    413     if (!HasArcs(s))
    414       Expand(s);
    415     return CacheImpl<B>::NumInputEpsilons(s);
    416   }
    417 
    418   size_t NumOutputEpsilons(StateId s) {
    419     if (!HasArcs(s))
    420       Expand(s);
    421     return CacheImpl<B>::NumOutputEpsilons(s);
    422   }
    423 
    424   uint64 Properties() const { return Properties(kFstProperties); }
    425 
    426   // Set error if found; return FST impl properties.
    427   uint64 Properties(uint64 mask) const {
    428     if ((mask & kError) &&
    429         (fst_->Properties(kError, false) || arc_sampler_->Error())) {
    430       SetProperties(kError, kError);
    431     }
    432     return FstImpl<Arc>::Properties(mask);
    433   }
    434 
    435   void InitArcIterator(StateId s, ArcIteratorData<B> *data) {
    436     if (!HasArcs(s))
    437       Expand(s);
    438     CacheImpl<B>::InitArcIterator(s, data);
    439   }
    440 
    441   // Computes the outgoing transitions from a state, creating new destination
    442   // states as needed.
    443   void Expand(StateId s) {
    444     if (s == superfinal_) {
    445       SetFinal(s, Weight::One());
    446       SetArcs(s);
    447       return;
    448     }
    449 
    450     SetFinal(s, Weight::Zero());
    451     const RandState<A> &rstate = *state_table_[s];
    452     arc_sampler_->Sample(rstate);
    453     ArcIterator< Fst<A> > aiter(*fst_, rstate.state_id);
    454     size_t narcs = fst_->NumArcs(rstate.state_id);
    455     for (;!arc_sampler_->Done(); arc_sampler_->Next()) {
    456       const pair<size_t, size_t> &sample_pair = arc_sampler_->Value();
    457       size_t pos = sample_pair.first;
    458       size_t count = sample_pair.second;
    459       double prob = static_cast<double>(count)/rstate.nsamples;
    460       if (pos < narcs) {  // regular transition
    461         aiter.Seek(sample_pair.first);
    462         const A &aarc = aiter.Value();
    463         Weight weight = weighted_ ? to_weight_(-log(prob)) : Weight::One();
    464         B barc(aarc.ilabel, aarc.olabel, weight, state_table_.size());
    465         AddArc(s, barc);
    466         RandState<A> *nrstate =
    467             new RandState<A>(aarc.nextstate, count, rstate.length + 1,
    468                              pos, &rstate);
    469         state_table_.push_back(nrstate);
    470       } else {            // super-final transition
    471         if (weighted_) {
    472           Weight weight = remove_total_weight_ ?
    473               to_weight_(-log(prob)) : to_weight_(-log(prob * npath_));
    474           SetFinal(s, weight);
    475         } else {
    476           if (superfinal_ == kNoLabel) {
    477             superfinal_ = state_table_.size();
    478             RandState<A> *nrstate = new RandState<A>(kNoStateId, 0, 0, 0, 0);
    479             state_table_.push_back(nrstate);
    480           }
    481           for (size_t n = 0; n < count; ++n) {
    482             B barc(0, 0, Weight::One(), superfinal_);
    483             AddArc(s, barc);
    484           }
    485         }
    486       }
    487     }
    488     SetArcs(s);
    489   }
    490 
    491  private:
    492   Fst<A> *fst_;
    493   S *arc_sampler_;
    494   size_t npath_;
    495   vector<RandState<A> *> state_table_;
    496   bool weighted_;
    497   bool remove_total_weight_;
    498   StateId superfinal_;
    499   WeightConvert<Log64Weight, Weight> to_weight_;
    500 
    501   void operator=(const RandGenFstImpl<A, B, S> &);  // disallow
    502 };
    503 
    504 
// Fst class to randomly generate paths through an FST; details controlled
// by RandGenFstOptions. Output format is a tree weighted by the
// path count.
    508 template <class A, class B, class S>
    509 class RandGenFst : public ImplToFst< RandGenFstImpl<A, B, S> > {
    510  public:
    511   friend class ArcIterator< RandGenFst<A, B, S> >;
    512   friend class StateIterator< RandGenFst<A, B, S> >;
    513   typedef B Arc;
    514   typedef S Sampler;
    515   typedef typename A::Label Label;
    516   typedef typename A::Weight Weight;
    517   typedef typename A::StateId StateId;
    518   typedef CacheState<B> State;
    519   typedef RandGenFstImpl<A, B, S> Impl;
    520 
    521   RandGenFst(const Fst<A> &fst, const RandGenFstOptions<S> &opts)
    522     : ImplToFst<Impl>(new Impl(fst, opts)) {}
    523 
    524   // See Fst<>::Copy() for doc.
    525  RandGenFst(const RandGenFst<A, B, S> &fst, bool safe = false)
    526     : ImplToFst<Impl>(fst, safe) {}
    527 
    528   // Get a copy of this RandGenFst. See Fst<>::Copy() for further doc.
    529   virtual RandGenFst<A, B, S> *Copy(bool safe = false) const {
    530     return new RandGenFst<A, B, S>(*this, safe);
    531   }
    532 
    533   virtual inline void InitStateIterator(StateIteratorData<B> *data) const;
    534 
    535   virtual void InitArcIterator(StateId s, ArcIteratorData<B> *data) const {
    536     GetImpl()->InitArcIterator(s, data);
    537   }
    538 
    539  private:
    540   // Makes visible to friends.
    541   Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
    542 
    543   void operator=(const RandGenFst<A, B, S> &fst);  // Disallow
    544 };
    545 
    546 
    547 
    548 // Specialization for RandGenFst.
    549 template <class A, class B, class S>
    550 class StateIterator< RandGenFst<A, B, S> >
    551     : public CacheStateIterator< RandGenFst<A, B, S> > {
    552  public:
    553   explicit StateIterator(const RandGenFst<A, B, S> &fst)
    554     : CacheStateIterator< RandGenFst<A, B, S> >(fst, fst.GetImpl()) {}
    555 
    556  private:
    557   DISALLOW_COPY_AND_ASSIGN(StateIterator);
    558 };
    559 
    560 
    561 // Specialization for RandGenFst.
    562 template <class A, class B, class S>
    563 class ArcIterator< RandGenFst<A, B, S> >
    564     : public CacheArcIterator< RandGenFst<A, B, S> > {
    565  public:
    566   typedef typename A::StateId StateId;
    567 
    568   ArcIterator(const RandGenFst<A, B, S> &fst, StateId s)
    569       : CacheArcIterator< RandGenFst<A, B, S> >(fst.GetImpl(), s) {
    570     if (!fst.GetImpl()->HasArcs(s))
    571       fst.GetImpl()->Expand(s);
    572   }
    573 
    574  private:
    575   DISALLOW_COPY_AND_ASSIGN(ArcIterator);
    576 };
    577 
    578 
    579 template <class A, class B, class S> inline
    580 void RandGenFst<A, B, S>::InitStateIterator(StateIteratorData<B> *data) const
    581 {
    582   data->base = new StateIterator< RandGenFst<A, B, S> >(*this);
    583 }
    584 
    585 // Options for random path generation.
    586 template <class S>
    587 struct RandGenOptions {
    588   const S &arc_selector;     // How an arc is selected at a state
    589   int max_length;            // Maximum path length
    590   size_t npath;              // # of paths to generate
    591   bool weighted;             // Output is tree weighted by path count; o.w.
    592                              // output unweighted union of paths.
    593   bool remove_total_weight;  // Remove total weight when output is weighted.
    594 
    595   RandGenOptions(const S &sel, int len = INT_MAX, size_t n = 1,
    596                  bool w = false, bool rw = false)
    597       : arc_selector(sel),
    598         max_length(len),
    599         npath(n),
    600         weighted(w),
    601         remove_total_weight(rw) {}
    602 };
    603 
    604 
    605 template <class IArc, class OArc>
    606 class RandGenVisitor {
    607  public:
    608   typedef typename IArc::Weight Weight;
    609   typedef typename IArc::StateId StateId;
    610 
    611   RandGenVisitor(MutableFst<OArc> *ofst) : ofst_(ofst) {}
    612 
    613   void InitVisit(const Fst<IArc> &ifst) {
    614     ifst_ = &ifst;
    615 
    616     ofst_->DeleteStates();
    617     ofst_->SetInputSymbols(ifst.InputSymbols());
    618     ofst_->SetOutputSymbols(ifst.OutputSymbols());
    619     if (ifst.Properties(kError, false))
    620       ofst_->SetProperties(kError, kError);
    621     path_.clear();
    622   }
    623 
    624   bool InitState(StateId s, StateId root) { return true; }
    625 
    626   bool TreeArc(StateId s, const IArc &arc) {
    627     if (ifst_->Final(arc.nextstate) == Weight::Zero()) {
    628       path_.push_back(arc);
    629     } else {
    630       OutputPath();
    631     }
    632     return true;
    633   }
    634 
    635   bool BackArc(StateId s, const IArc &arc) {
    636     FSTERROR() << "RandGenVisitor: cyclic input";
    637     ofst_->SetProperties(kError, kError);
    638     return false;
    639   }
    640 
    641   bool ForwardOrCrossArc(StateId s, const IArc &arc) {
    642     OutputPath();
    643     return true;
    644   }
    645 
    646   void FinishState(StateId s, StateId p, const IArc *) {
    647     if (p != kNoStateId && ifst_->Final(s) == Weight::Zero())
    648       path_.pop_back();
    649   }
    650 
    651   void FinishVisit() {}
    652 
    653  private:
    654   void OutputPath() {
    655     if (ofst_->Start() == kNoStateId) {
    656       StateId start = ofst_->AddState();
    657       ofst_->SetStart(start);
    658     }
    659 
    660     StateId src = ofst_->Start();
    661     for (size_t i = 0; i < path_.size(); ++i) {
    662       StateId dest = ofst_->AddState();
    663       OArc arc(path_[i].ilabel, path_[i].olabel, Weight::One(), dest);
    664       ofst_->AddArc(src, arc);
    665       src = dest;
    666     }
    667     ofst_->SetFinal(src, Weight::One());
    668   }
    669 
    670   const Fst<IArc> *ifst_;
    671   MutableFst<OArc> *ofst_;
    672   vector<OArc> path_;
    673 
    674   DISALLOW_COPY_AND_ASSIGN(RandGenVisitor);
    675 };
    676 
    677 
    678 // Randomly generate paths through an FST; details controlled by
    679 // RandGenOptions.
    680 template<class IArc, class OArc, class Selector>
    681 void RandGen(const Fst<IArc> &ifst, MutableFst<OArc> *ofst,
    682              const RandGenOptions<Selector> &opts) {
    683   typedef ArcSampler<IArc, Selector> Sampler;
    684   typedef RandGenFst<IArc, OArc, Sampler> RandFst;
    685   typedef typename OArc::StateId StateId;
    686   typedef typename OArc::Weight Weight;
    687 
    688   Sampler* arc_sampler = new Sampler(ifst, opts.arc_selector, opts.max_length);
    689   RandGenFstOptions<Sampler> fopts(CacheOptions(true, 0), arc_sampler,
    690                                    opts.npath, opts.weighted,
    691                                    opts.remove_total_weight);
    692   RandFst rfst(ifst, fopts);
    693   if (opts.weighted) {
    694     *ofst = rfst;
    695   } else {
    696     RandGenVisitor<IArc, OArc> rand_visitor(ofst);
    697     DfsVisit(rfst, &rand_visitor);
    698   }
    699 }
    700 
    701 // Randomly generate a path through an FST with the uniform distribution
    702 // over the transitions.
    703 template<class IArc, class OArc>
    704 void RandGen(const Fst<IArc> &ifst, MutableFst<OArc> *ofst) {
    705   UniformArcSelector<IArc> uniform_selector;
    706   RandGenOptions< UniformArcSelector<IArc> > opts(uniform_selector);
    707   RandGen(ifst, ofst, opts);
    708 }
    709 
    710 }  // namespace fst
    711 
    712 #endif  // FST_LIB_RANDGEN_H__
    713