1 // prune.h 2 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Copyright 2005-2010 Google, Inc. 16 // Author: allauzen (at) google.com (Cyril Allauzen) 17 // 18 // \file 19 // Functions implementing pruning. 20 21 #ifndef FST_LIB_PRUNE_H__ 22 #define FST_LIB_PRUNE_H__ 23 24 #include <vector> 25 using std::vector; 26 27 #include <fst/arcfilter.h> 28 #include <fst/heap.h> 29 #include <fst/shortest-distance.h> 30 31 32 namespace fst { 33 34 template <class A, class ArcFilter> 35 class PruneOptions { 36 public: 37 typedef typename A::Weight Weight; 38 typedef typename A::StateId StateId; 39 40 // Pruning weight threshold. 41 Weight weight_threshold; 42 // Pruning state threshold. 43 StateId state_threshold; 44 // Arc filter. 45 ArcFilter filter; 46 // If non-zero, passes in pre-computed shortest distance to final states. 47 const vector<Weight> *distance; 48 // Determines the degree of convergence required when computing shortest 49 // distances. 50 float delta; 51 52 explicit PruneOptions(const Weight& w, StateId s, ArcFilter f, 53 vector<Weight> *d = 0, float e = kDelta) 54 : weight_threshold(w), 55 state_threshold(s), 56 filter(f), 57 distance(d), 58 delta(e) {} 59 private: 60 PruneOptions(); // disallow 61 }; 62 63 64 template <class S, class W> 65 class PruneCompare { 66 public: 67 typedef S StateId; 68 typedef W Weight; 69 70 PruneCompare(const vector<Weight> &idistance, 71 const vector<Weight> &fdistance) 72 : idistance_(idistance), fdistance_(fdistance) {} 73 74 bool operator()(const StateId x, const StateId y) const { 75 Weight wx = Times(x < idistance_.size() ? idistance_[x] : Weight::Zero(), 76 x < fdistance_.size() ? fdistance_[x] : Weight::Zero()); 77 Weight wy = Times(y < idistance_.size() ? idistance_[y] : Weight::Zero(), 78 y < fdistance_.size() ? fdistance_[y] : Weight::Zero()); 79 return less_(wx, wy); 80 } 81 82 private: 83 const vector<Weight> &idistance_; 84 const vector<Weight> &fdistance_; 85 NaturalLess<Weight> less_; 86 }; 87 88 89 90 // Pruning algorithm: this version modifies its input and it takes an 91 // options class as an argment. Delete states and arcs in 'fst' that 92 // do not belong to a successful path whose weight is no more than 93 // the weight of the shortest path Times() 'opts.weight_threshold'. 94 // When 'opts.state_threshold != kNoStateId', the resulting transducer 95 // will restricted further to have at most 'opts.state_threshold' 96 // states. Weights need to be commutative and have the path 97 // property. The weight 'w' of any cycle needs to be bounded, i.e., 98 // 'Plus(w, W::One()) = One()'. 99 template <class Arc, class ArcFilter> 100 void Prune(MutableFst<Arc> *fst, 101 const PruneOptions<Arc, ArcFilter> &opts) { 102 typedef typename Arc::Weight Weight; 103 typedef typename Arc::StateId StateId; 104 105 if ((Weight::Properties() & (kPath | kCommutative)) 106 != (kPath | kCommutative)) { 107 FSTERROR() << "Prune: Weight needs to have the path property and" 108 << " be commutative: " 109 << Weight::Type(); 110 fst->SetProperties(kError, kError); 111 return; 112 } 113 StateId ns = fst->NumStates(); 114 if (ns == 0) return; 115 vector<Weight> idistance(ns, Weight::Zero()); 116 vector<Weight> tmp; 117 if (!opts.distance) { 118 tmp.reserve(ns); 119 ShortestDistance(*fst, &tmp, true, opts.delta); 120 } 121 const vector<Weight> *fdistance = opts.distance ? opts.distance : &tmp; 122 123 if ((opts.state_threshold == 0) || 124 (fdistance->size() <= fst->Start()) || 125 ((*fdistance)[fst->Start()] == Weight::Zero())) { 126 fst->DeleteStates(); 127 return; 128 } 129 PruneCompare<StateId, Weight> compare(idistance, *fdistance); 130 Heap< StateId, PruneCompare<StateId, Weight>, false> heap(compare); 131 vector<bool> visited(ns, false); 132 vector<size_t> enqueued(ns, kNoKey); 133 vector<StateId> dead; 134 dead.push_back(fst->AddState()); 135 NaturalLess<Weight> less; 136 Weight limit = Times((*fdistance)[fst->Start()], opts.weight_threshold); 137 138 StateId num_visited = 0; 139 StateId s = fst->Start(); 140 if (!less(limit, (*fdistance)[s])) { 141 idistance[s] = Weight::One(); 142 enqueued[s] = heap.Insert(s); 143 ++num_visited; 144 } 145 146 while (!heap.Empty()) { 147 s = heap.Top(); 148 heap.Pop(); 149 enqueued[s] = kNoKey; 150 visited[s] = true; 151 if (less(limit, Times(idistance[s], fst->Final(s)))) 152 fst->SetFinal(s, Weight::Zero()); 153 for (MutableArcIterator< MutableFst<Arc> > ait(fst, s); 154 !ait.Done(); 155 ait.Next()) { 156 Arc arc = ait.Value(); 157 if (!opts.filter(arc)) continue; 158 Weight weight = Times(Times(idistance[s], arc.weight), 159 arc.nextstate < fdistance->size() 160 ? (*fdistance)[arc.nextstate] 161 : Weight::Zero()); 162 if (less(limit, weight)) { 163 arc.nextstate = dead[0]; 164 ait.SetValue(arc); 165 continue; 166 } 167 if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) 168 idistance[arc.nextstate] = Times(idistance[s], arc.weight); 169 if (visited[arc.nextstate]) continue; 170 if ((opts.state_threshold != kNoStateId) && 171 (num_visited >= opts.state_threshold)) 172 continue; 173 if (enqueued[arc.nextstate] == kNoKey) { 174 enqueued[arc.nextstate] = heap.Insert(arc.nextstate); 175 ++num_visited; 176 } else { 177 heap.Update(enqueued[arc.nextstate], arc.nextstate); 178 } 179 } 180 } 181 for (size_t i = 0; i < visited.size(); ++i) 182 if (!visited[i]) dead.push_back(i); 183 fst->DeleteStates(dead); 184 } 185 186 187 // Pruning algorithm: this version modifies its input and simply takes 188 // the pruning threshold as an argument. Delete states and arcs in 189 // 'fst' that do not belong to a successful path whose weight is no 190 // more than the weight of the shortest path Times() 191 // 'weight_threshold'. When 'state_threshold != kNoStateId', the 192 // resulting transducer will be restricted further to have at most 193 // 'opts.state_threshold' states. Weights need to be commutative and 194 // have the path property. The weight 'w' of any cycle needs to be 195 // bounded, i.e., 'Plus(w, W::One()) = One()'. 196 template <class Arc> 197 void Prune(MutableFst<Arc> *fst, 198 typename Arc::Weight weight_threshold, 199 typename Arc::StateId state_threshold = kNoStateId, 200 double delta = kDelta) { 201 PruneOptions<Arc, AnyArcFilter<Arc> > opts(weight_threshold, state_threshold, 202 AnyArcFilter<Arc>(), 0, delta); 203 Prune(fst, opts); 204 } 205 206 207 // Pruning algorithm: this version writes the pruned input Fst to an 208 // output MutableFst and it takes an options class as an argument. 209 // 'ofst' contains states and arcs that belong to a successful path in 210 // 'ifst' whose weight is no more than the weight of the shortest path 211 // Times() 'opts.weight_threshold'. When 'opts.state_threshold != 212 // kNoStateId', 'ofst' will be restricted further to have at most 213 // 'opts.state_threshold' states. Weights need to be commutative and 214 // have the path property. The weight 'w' of any cycle needs to be 215 // bounded, i.e., 'Plus(w, W::One()) = One()'. 216 template <class Arc, class ArcFilter> 217 void Prune(const Fst<Arc> &ifst, 218 MutableFst<Arc> *ofst, 219 const PruneOptions<Arc, ArcFilter> &opts) { 220 typedef typename Arc::Weight Weight; 221 typedef typename Arc::StateId StateId; 222 223 if ((Weight::Properties() & (kPath | kCommutative)) 224 != (kPath | kCommutative)) { 225 FSTERROR() << "Prune: Weight needs to have the path property and" 226 << " be commutative: " 227 << Weight::Type(); 228 ofst->SetProperties(kError, kError); 229 return; 230 } 231 ofst->DeleteStates(); 232 ofst->SetInputSymbols(ifst.InputSymbols()); 233 ofst->SetOutputSymbols(ifst.OutputSymbols()); 234 if (ifst.Start() == kNoStateId) 235 return; 236 NaturalLess<Weight> less; 237 if (less(opts.weight_threshold, Weight::One()) || 238 (opts.state_threshold == 0)) 239 return; 240 vector<Weight> idistance; 241 vector<Weight> tmp; 242 if (!opts.distance) 243 ShortestDistance(ifst, &tmp, true, opts.delta); 244 const vector<Weight> *fdistance = opts.distance ? opts.distance : &tmp; 245 246 if ((fdistance->size() <= ifst.Start()) || 247 ((*fdistance)[ifst.Start()] == Weight::Zero())) { 248 return; 249 } 250 PruneCompare<StateId, Weight> compare(idistance, *fdistance); 251 Heap< StateId, PruneCompare<StateId, Weight>, false> heap(compare); 252 vector<StateId> copy; 253 vector<size_t> enqueued; 254 vector<bool> visited; 255 256 StateId s = ifst.Start(); 257 Weight limit = Times(s < fdistance->size() ? (*fdistance)[s] : Weight::Zero(), 258 opts.weight_threshold); 259 while (copy.size() <= s) 260 copy.push_back(kNoStateId); 261 copy[s] = ofst->AddState(); 262 ofst->SetStart(copy[s]); 263 while (idistance.size() <= s) 264 idistance.push_back(Weight::Zero()); 265 idistance[s] = Weight::One(); 266 while (enqueued.size() <= s) { 267 enqueued.push_back(kNoKey); 268 visited.push_back(false); 269 } 270 enqueued[s] = heap.Insert(s); 271 272 while (!heap.Empty()) { 273 s = heap.Top(); 274 heap.Pop(); 275 enqueued[s] = kNoKey; 276 visited[s] = true; 277 if (!less(limit, Times(idistance[s], ifst.Final(s)))) 278 ofst->SetFinal(copy[s], ifst.Final(s)); 279 for (ArcIterator< Fst<Arc> > ait(ifst, s); 280 !ait.Done(); 281 ait.Next()) { 282 const Arc &arc = ait.Value(); 283 if (!opts.filter(arc)) continue; 284 Weight weight = Times(Times(idistance[s], arc.weight), 285 arc.nextstate < fdistance->size() 286 ? (*fdistance)[arc.nextstate] 287 : Weight::Zero()); 288 if (less(limit, weight)) continue; 289 if ((opts.state_threshold != kNoStateId) && 290 (ofst->NumStates() >= opts.state_threshold)) 291 continue; 292 while (idistance.size() <= arc.nextstate) 293 idistance.push_back(Weight::Zero()); 294 if (less(Times(idistance[s], arc.weight), 295 idistance[arc.nextstate])) 296 idistance[arc.nextstate] = Times(idistance[s], arc.weight); 297 while (copy.size() <= arc.nextstate) 298 copy.push_back(kNoStateId); 299 if (copy[arc.nextstate] == kNoStateId) 300 copy[arc.nextstate] = ofst->AddState(); 301 ofst->AddArc(copy[s], Arc(arc.ilabel, arc.olabel, arc.weight, 302 copy[arc.nextstate])); 303 while (enqueued.size() <= arc.nextstate) { 304 enqueued.push_back(kNoKey); 305 visited.push_back(false); 306 } 307 if (visited[arc.nextstate]) continue; 308 if (enqueued[arc.nextstate] == kNoKey) 309 enqueued[arc.nextstate] = heap.Insert(arc.nextstate); 310 else 311 heap.Update(enqueued[arc.nextstate], arc.nextstate); 312 } 313 } 314 } 315 316 317 // Pruning algorithm: this version writes the pruned input Fst to an 318 // output MutableFst and simply takes the pruning threshold as an 319 // argument. 'ofst' contains states and arcs that belong to a 320 // successful path in 'ifst' whose weight is no more than 321 // the weight of the shortest path Times() 'weight_threshold'. When 322 // 'state_threshold != kNoStateId', 'ofst' will be restricted further 323 // to have at most 'opts.state_threshold' states. Weights need to be 324 // commutative and have the path property. The weight 'w' of any cycle 325 // needs to be bounded, i.e., 'Plus(w, W::One()) = W::One()'. 326 template <class Arc> 327 void Prune(const Fst<Arc> &ifst, 328 MutableFst<Arc> *ofst, 329 typename Arc::Weight weight_threshold, 330 typename Arc::StateId state_threshold = kNoStateId, 331 float delta = kDelta) { 332 PruneOptions<Arc, AnyArcFilter<Arc> > opts(weight_threshold, state_threshold, 333 AnyArcFilter<Arc>(), 0, delta); 334 Prune(ifst, ofst, opts); 335 } 336 337 } // namespace fst 338 339 #endif // FST_LIB_PRUNE_H_ 340