Home | History | Annotate | Download | only in lib
      1 // randgen.h
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 //
     16 // \file
     17 // Function to generate random paths through an FST.
     18 
     19 #ifndef FST_LIB_RANDGEN_H__
     20 #define FST_LIB_RANDGEN_H__
     21 
     22 #include <cmath>
     23 #include <cstdlib>
     24 #include <ctime>
     25 
     26 #include "fst/lib/mutable-fst.h"
     27 
     28 namespace fst {
     29 
     30 //
     31 // ARC SELECTORS - these function objects are used to select a random
     32 // transition to take from an FST's state. They should return a number
     33 // N s.t. 0 <= N <= NumArcs(). If N < NumArcs(), then the N-th
     34 // transition is selected. If N == NumArcs(), then the final weight at
     35 // that state is selected (i.e., the 'super-final' transition is selected).
     36 // It can be assumed these will not be called unless either there
     37 // are transitions leaving the state and/or the state is final.
     38 //
     39 
     40 // Randomly selects a transition using the uniform distribution.
     41 template <class A>
     42 struct UniformArcSelector {
     43   typedef typename A::StateId StateId;
     44   typedef typename A::Weight Weight;
     45 
     46   UniformArcSelector(int seed = time(0)) { srand(seed); }
     47 
     48   size_t operator()(const Fst<A> &fst, StateId s) const {
     49     double r = rand()/(RAND_MAX + 1.0);
     50     size_t n = fst.NumArcs(s);
     51     if (fst.Final(s) != Weight::Zero())
     52       ++n;
     53     return static_cast<size_t>(r * n);
     54   }
     55 };
     56 
     57 // Randomly selects a transition w.r.t. the weights treated as negative
     58 // log probabilities after normalizing for the total weight leaving
     59 // the state). Weight::zero transitions are disregarded.
     60 // Assumes Weight::Value() accesses the floating point
     61 // representation of the weight.
     62 template <class A>
     63 struct LogProbArcSelector {
     64   typedef typename A::StateId StateId;
     65   typedef typename A::Weight Weight;
     66 
     67   LogProbArcSelector(int seed = time(0)) { srand(seed); }
     68 
     69   size_t operator()(const Fst<A> &fst, StateId s) const {
     70     // Find total weight leaving state
     71     double sum = 0.0;
     72     for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done();
     73          aiter.Next()) {
     74       const A &arc = aiter.Value();
     75       sum += exp(-arc.weight.Value());
     76     }
     77     sum += exp(-fst.Final(s).Value());
     78 
     79     double r = rand()/(RAND_MAX + 1.0);
     80     double p = 0.0;
     81     int n = 0;
     82     for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done();
     83          aiter.Next(), ++n) {
     84       const A &arc = aiter.Value();
     85       p += exp(-arc.weight.Value());
     86       if (p > r * sum) return n;
     87     }
     88     return n;
     89   }
     90 };
     91 
// Convenience definitions: LogProbArcSelector instantiated for the
// StdArc and LogArc arc types.
typedef LogProbArcSelector<StdArc> StdArcSelector;
typedef LogProbArcSelector<LogArc> LogArcSelector;
     95 
     96 
     97 // Options for random path generation.
     98 template <class S>
     99 struct RandGenOptions {
    100   const S &arc_selector;  // How an arc is selected at a state
    101   int max_length;         // Maximum path length
    102   size_t npath;           // # of paths to generate
    103 
    104   // These are used internally by RandGen
    105   int64 source;           // 'ifst' state to expand
    106   int64 dest;             // 'ofst' state to append
    107 
    108   RandGenOptions(const S &sel, int len = INT_MAX, size_t n = 1)
    109     : arc_selector(sel), max_length(len), npath(n),
    110        source(kNoStateId), dest(kNoStateId) {}
    111 };
    112 
    113 
    114 // Randomly generate paths through an FST; details controlled by
    115 // RandGenOptions.
    116 template<class Arc, class ArcSelector>
    117 void RandGen(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
    118 	     const RandGenOptions<ArcSelector> &opts) {
    119   typedef typename Arc::Weight Weight;
    120 
    121   if (opts.npath == 0 || opts.max_length == 0 || ifst.Start() == kNoStateId)
    122     return;
    123 
    124   if (opts.source == kNoStateId) {   // first call
    125     ofst->DeleteStates();
    126     ofst->SetInputSymbols(ifst.InputSymbols());
    127     ofst->SetOutputSymbols(ifst.OutputSymbols());
    128     ofst->SetStart(ofst->AddState());
    129     RandGenOptions<ArcSelector> nopts(opts);
    130     nopts.source = ifst.Start();
    131     nopts.dest = ofst->Start();
    132     for (; nopts.npath > 0; --nopts.npath)
    133       RandGen(ifst, ofst, nopts);
    134   } else {
    135     if (ifst.NumArcs(opts.source) == 0 &&
    136 	ifst.Final(opts.source) == Weight::Zero())  // Non-coaccessible
    137       return;
    138     // Pick a random transition from the source state
    139     size_t n = opts.arc_selector(ifst, opts.source);
    140     if (n == ifst.NumArcs(opts.source)) {  // Take 'super-final' transition
    141       ofst->SetFinal(opts.dest, Weight::One());
    142     } else {
    143       ArcIterator< Fst<Arc> > aiter(ifst, opts.source);
    144       aiter.Seek(n);
    145       const Arc &iarc = aiter.Value();
    146       Arc oarc(iarc.ilabel, iarc.olabel, Weight::One(), ofst->AddState());
    147       ofst->AddArc(opts.dest, oarc);
    148 
    149       RandGenOptions<ArcSelector> nopts(opts);
    150       nopts.source = iarc.nextstate;
    151       nopts.dest = oarc.nextstate;
    152       --nopts.max_length;
    153       RandGen(ifst, ofst, nopts);
    154     }
    155   }
    156 }
    157 
    158 // Randomly generate a path through an FST with the uniform distribution
    159 // over the transitions.
    160 template<class Arc>
    161 void RandGen(const Fst<Arc> &ifst, MutableFst<Arc> *ofst) {
    162   UniformArcSelector<Arc> uniform_selector;
    163   RandGenOptions< UniformArcSelector<Arc> > opts(uniform_selector);
    164   RandGen(ifst, ofst, opts);
    165 }
    166 
    167 }  // namespace fst
    168 
    169 #endif  // FST_LIB_RANDGEN_H__
    170