1 // randgen.h 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // 16 // \file 17 // Function to generate random paths through an FST. 18 19 #ifndef FST_LIB_RANDGEN_H__ 20 #define FST_LIB_RANDGEN_H__ 21 22 #include <cmath> 23 #include <cstdlib> 24 #include <ctime> 25 26 #include "fst/lib/mutable-fst.h" 27 28 namespace fst { 29 30 // 31 // ARC SELECTORS - these function objects are used to select a random 32 // transition to take from an FST's state. They should return a number 33 // N s.t. 0 <= N <= NumArcs(). If N < NumArcs(), then the N-th 34 // transition is selected. If N == NumArcs(), then the final weight at 35 // that state is selected (i.e., the 'super-final' transition is selected). 36 // It can be assumed these will not be called unless either there 37 // are transitions leaving the state and/or the state is final. 38 // 39 40 // Randomly selects a transition using the uniform distribution. 41 template <class A> 42 struct UniformArcSelector { 43 typedef typename A::StateId StateId; 44 typedef typename A::Weight Weight; 45 46 UniformArcSelector(int seed = time(0)) { srand(seed); } 47 48 size_t operator()(const Fst<A> &fst, StateId s) const { 49 double r = rand()/(RAND_MAX + 1.0); 50 size_t n = fst.NumArcs(s); 51 if (fst.Final(s) != Weight::Zero()) 52 ++n; 53 return static_cast<size_t>(r * n); 54 } 55 }; 56 57 // Randomly selects a transition w.r.t. the weights treated as negative 58 // log probabilities after normalizing for the total weight leaving 59 // the state). Weight::zero transitions are disregarded. 60 // Assumes Weight::Value() accesses the floating point 61 // representation of the weight. 62 template <class A> 63 struct LogProbArcSelector { 64 typedef typename A::StateId StateId; 65 typedef typename A::Weight Weight; 66 67 LogProbArcSelector(int seed = time(0)) { srand(seed); } 68 69 size_t operator()(const Fst<A> &fst, StateId s) const { 70 // Find total weight leaving state 71 double sum = 0.0; 72 for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done(); 73 aiter.Next()) { 74 const A &arc = aiter.Value(); 75 sum += exp(-arc.weight.Value()); 76 } 77 sum += exp(-fst.Final(s).Value()); 78 79 double r = rand()/(RAND_MAX + 1.0); 80 double p = 0.0; 81 int n = 0; 82 for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done(); 83 aiter.Next(), ++n) { 84 const A &arc = aiter.Value(); 85 p += exp(-arc.weight.Value()); 86 if (p > r * sum) return n; 87 } 88 return n; 89 } 90 }; 91 92 // Convenience definitions 93 typedef LogProbArcSelector<StdArc> StdArcSelector; 94 typedef LogProbArcSelector<LogArc> LogArcSelector; 95 96 97 // Options for random path generation. 98 template <class S> 99 struct RandGenOptions { 100 const S &arc_selector; // How an arc is selected at a state 101 int max_length; // Maximum path length 102 size_t npath; // # of paths to generate 103 104 // These are used internally by RandGen 105 int64 source; // 'ifst' state to expand 106 int64 dest; // 'ofst' state to append 107 108 RandGenOptions(const S &sel, int len = INT_MAX, size_t n = 1) 109 : arc_selector(sel), max_length(len), npath(n), 110 source(kNoStateId), dest(kNoStateId) {} 111 }; 112 113 114 // Randomly generate paths through an FST; details controlled by 115 // RandGenOptions. 116 template<class Arc, class ArcSelector> 117 void RandGen(const Fst<Arc> &ifst, MutableFst<Arc> *ofst, 118 const RandGenOptions<ArcSelector> &opts) { 119 typedef typename Arc::Weight Weight; 120 121 if (opts.npath == 0 || opts.max_length == 0 || ifst.Start() == kNoStateId) 122 return; 123 124 if (opts.source == kNoStateId) { // first call 125 ofst->DeleteStates(); 126 ofst->SetInputSymbols(ifst.InputSymbols()); 127 ofst->SetOutputSymbols(ifst.OutputSymbols()); 128 ofst->SetStart(ofst->AddState()); 129 RandGenOptions<ArcSelector> nopts(opts); 130 nopts.source = ifst.Start(); 131 nopts.dest = ofst->Start(); 132 for (; nopts.npath > 0; --nopts.npath) 133 RandGen(ifst, ofst, nopts); 134 } else { 135 if (ifst.NumArcs(opts.source) == 0 && 136 ifst.Final(opts.source) == Weight::Zero()) // Non-coaccessible 137 return; 138 // Pick a random transition from the source state 139 size_t n = opts.arc_selector(ifst, opts.source); 140 if (n == ifst.NumArcs(opts.source)) { // Take 'super-final' transition 141 ofst->SetFinal(opts.dest, Weight::One()); 142 } else { 143 ArcIterator< Fst<Arc> > aiter(ifst, opts.source); 144 aiter.Seek(n); 145 const Arc &iarc = aiter.Value(); 146 Arc oarc(iarc.ilabel, iarc.olabel, Weight::One(), ofst->AddState()); 147 ofst->AddArc(opts.dest, oarc); 148 149 RandGenOptions<ArcSelector> nopts(opts); 150 nopts.source = iarc.nextstate; 151 nopts.dest = oarc.nextstate; 152 --nopts.max_length; 153 RandGen(ifst, ofst, nopts); 154 } 155 } 156 } 157 158 // Randomly generate a path through an FST with the uniform distribution 159 // over the transitions. 160 template<class Arc> 161 void RandGen(const Fst<Arc> &ifst, MutableFst<Arc> *ofst) { 162 UniformArcSelector<Arc> uniform_selector; 163 RandGenOptions< UniformArcSelector<Arc> > opts(uniform_selector); 164 RandGen(ifst, ofst, opts); 165 } 166 167 } // namespace fst 168 169 #endif // FST_LIB_RANDGEN_H__ 170