1 // prune.h 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Author: allauzen (at) cs.nyu.edu (Cyril Allauzen) 16 // 17 // \file 18 // Functions implementing pruning. 19 20 #ifndef FST_LIB_PRUNE_H__ 21 #define FST_LIB_PRUNE_H__ 22 23 #include "fst/lib/arcfilter.h" 24 #include "fst/lib/shortest-distance.h" 25 26 namespace fst { 27 28 template <class A, class ArcFilter> 29 class PruneOptions { 30 public: 31 typedef typename A::Weight Weight; 32 33 // Pruning threshold. 34 Weight threshold; 35 // Arc filter. 36 ArcFilter filter; 37 // If non-zero, passes in pre-computed shortest distance from initial state 38 // (possibly resized). 39 vector<Weight> *idistance; 40 // If non-zero, passes in pre-computed shortest distance to final states 41 // (possibly resized). 42 vector<Weight> *fdistance; 43 44 PruneOptions(const Weight& t, ArcFilter f, vector<Weight> *id = 0, 45 vector<Weight> *fd = 0) 46 : threshold(t), filter(f), idistance(id), fdistance(fd) {} 47 }; 48 49 50 // Pruning algorithm: this version modifies its input and it takes an 51 // options class as an argment. Delete states and arcs in 'fst' that 52 // do not belong to a successful path whose weight is no more than 53 // 'opts.threshold' Times() the weight of the shortest path. Weights 54 // need to be commutative and have the path property. 55 template <class Arc, class ArcFilter> 56 void Prune(MutableFst<Arc> *fst, 57 const PruneOptions<Arc, ArcFilter> &opts) { 58 typedef typename Arc::Weight Weight; 59 typedef typename Arc::StateId StateId; 60 61 if ((Weight::Properties() & (kPath | kCommutative)) 62 != (kPath | kCommutative)) 63 LOG(FATAL) << "Prune: Weight needs to have the path property and" 64 << " be commutative: " 65 << Weight::Type(); 66 67 StateId ns = fst->NumStates(); 68 if (ns == 0) return; 69 70 vector<Weight> *idistance = opts.idistance; 71 vector<Weight> *fdistance = opts.fdistance; 72 73 if (!idistance) { 74 idistance = new vector<Weight>(ns, Weight::Zero()); 75 ShortestDistance(*fst, idistance, false); 76 } else { 77 idistance->resize(ns, Weight::Zero()); 78 } 79 80 if (!fdistance) { 81 fdistance = new vector<Weight>(ns, Weight::Zero()); 82 ShortestDistance(*fst, fdistance, true); 83 } else { 84 fdistance->resize(ns, Weight::Zero()); 85 } 86 87 vector<StateId> dead; 88 dead.push_back(fst->AddState()); 89 NaturalLess<Weight> less; 90 Weight ceiling = Times((*fdistance)[fst->Start()], opts.threshold); 91 92 for (StateId state = 0; state < ns; ++state) { 93 if (less(ceiling, Times((*idistance)[state], (*fdistance)[state]))) { 94 dead.push_back(state); 95 continue; 96 } 97 for (MutableArcIterator< MutableFst<Arc> > it(fst, state); 98 !it.Done(); 99 it.Next()) { 100 Arc arc = it.Value(); 101 if (!opts.filter(arc)) continue; 102 Weight weight = Times(Times((*idistance)[state], arc.weight), 103 (*fdistance)[arc.nextstate]); 104 if(less(ceiling, weight)) { 105 arc.nextstate = dead[0]; 106 it.SetValue(arc); 107 } 108 } 109 if (less(ceiling, Times((*idistance)[state], fst->Final(state)))) 110 fst->SetFinal(state, Weight::Zero()); 111 } 112 113 fst->DeleteStates(dead); 114 115 if (!opts.idistance) 116 delete idistance; 117 if (!opts.fdistance) 118 delete fdistance; 119 } 120 121 122 // Pruning algorithm: this version modifies its input and simply takes 123 // the pruning threshold as an argument. Delete states and arcs in 124 // 'fst' that do not belong to a successful path whose weight is no 125 // more than 'opts.threshold' Times() the weight of the shortest 126 // path. Weights need to be commutative and have the path property. 127 template <class Arc> 128 void Prune(MutableFst<Arc> *fst, typename Arc::Weight threshold) { 129 PruneOptions<Arc, AnyArcFilter<Arc> > opts(threshold, AnyArcFilter<Arc>()); 130 Prune(fst, opts); 131 } 132 133 134 // Pruning algorithm: this version writes the pruned input Fst to an 135 // output MutableFst and it takes an options class as an argument. 136 // 'ofst' contains states and arcs that belong to a successful path in 137 // 'ifst' whose weight is no more than 'opts.threshold' Times() the 138 // weight of the shortest path. Weights need to be commutative and 139 // have the path property. 140 template <class Arc, class ArcFilter> 141 void Prune(const Fst<Arc> &ifst, 142 MutableFst<Arc> *ofst, 143 const PruneOptions<Arc, ArcFilter> &opts) { 144 typedef typename Arc::Weight Weight; 145 typedef typename Arc::StateId StateId; 146 147 if ((Weight::Properties() & (kPath | kCommutative)) 148 != (kPath | kCommutative)) 149 LOG(FATAL) << "Prune: Weight needs to have the path property and" 150 << " be commutative: " 151 << Weight::Type(); 152 153 ofst->DeleteStates(); 154 155 if (ifst.Start() == kNoStateId) 156 return; 157 158 vector<Weight> *idistance = opts.idistance; 159 vector<Weight> *fdistance = opts.fdistance; 160 161 if (!idistance) { 162 idistance = new vector<Weight>; 163 ShortestDistance(ifst, idistance, false); 164 } 165 166 if (!fdistance) { 167 fdistance = new vector<Weight>; 168 ShortestDistance(ifst, fdistance, true); 169 } 170 171 vector<StateId> copy; 172 NaturalLess<Weight> less; 173 while (fdistance->size() <= ifst.Start()) 174 fdistance->push_back(Weight::Zero()); 175 Weight ceiling = Times((*fdistance)[ifst.Start()], opts.threshold); 176 177 for (StateIterator< Fst<Arc> > sit(ifst); 178 !sit.Done(); 179 sit.Next()) { 180 StateId state = sit.Value(); 181 while (idistance->size() <= state) 182 idistance->push_back(Weight::Zero()); 183 while (fdistance->size() <= state) 184 fdistance->push_back(Weight::Zero()); 185 while (copy.size() <= state) 186 copy.push_back(kNoStateId); 187 188 if (less(ceiling, Times((*idistance)[state], (*fdistance)[state]))) 189 continue; 190 191 if (copy[state] == kNoStateId) 192 copy[state] = ofst->AddState(); 193 if (!less(ceiling, Times((*idistance)[state], ifst.Final(state)))) 194 ofst->SetFinal(copy[state], ifst.Final(state)); 195 196 for (ArcIterator< Fst<Arc> > ait(ifst, state); 197 !ait.Done(); 198 ait.Next()) { 199 Arc arc = ait.Value(); 200 201 if (!opts.filter(arc)) continue; 202 203 while (idistance->size() <= arc.nextstate) 204 idistance->push_back(Weight::Zero()); 205 while (fdistance->size() <= arc.nextstate) 206 fdistance->push_back(Weight::Zero()); 207 while (copy.size() <= arc.nextstate) 208 copy.push_back(kNoStateId); 209 210 Weight weight = Times(Times((*idistance)[state], arc.weight), 211 (*fdistance)[arc.nextstate]); 212 213 if (!less(ceiling, weight)) { 214 if (copy[arc.nextstate] == kNoStateId) 215 copy[arc.nextstate] = ofst->AddState(); 216 arc.nextstate = copy[arc.nextstate]; 217 ofst->AddArc(copy[state], arc); 218 } 219 } 220 } 221 222 ofst->SetStart(copy[ifst.Start()]); 223 224 if (!opts.idistance) 225 delete idistance; 226 if (!opts.fdistance) 227 delete fdistance; 228 } 229 230 231 // Pruning algorithm: this version writes the pruned input Fst to an 232 // output MutableFst and simply takes the pruning threshold as an 233 // argument. 'ofst' contains states and arcs that belong to a 234 // successful path in 'ifst' whose weight is no more than 235 // 'opts.threshold' Times() the weight of the shortest path. Weights 236 // need to be commutative and have the path property. 237 template <class Arc> 238 void Prune(const Fst<Arc> &ifst, 239 MutableFst<Arc> *ofst, 240 typename Arc::Weight threshold) { 241 PruneOptions<Arc, AnyArcFilter<Arc> > opts(threshold, AnyArcFilter<Arc>()); 242 Prune(ifst, ofst, opts); 243 } 244 245 } // namespace fst 246 247 #endif // FST_LIB_PRUNE_H_ 248