1 // shortest-path.h 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Author: allauzen (at) cs.nyu.edu (Cyril Allauzen) 16 // 17 // \file 18 // Functions to find shortest paths in an FST. 19 20 #ifndef FST_LIB_SHORTEST_PATH_H__ 21 #define FST_LIB_SHORTEST_PATH_H__ 22 23 #include <functional> 24 25 #include "fst/lib/cache.h" 26 #include "fst/lib/queue.h" 27 #include "fst/lib/shortest-distance.h" 28 #include "fst/lib/test-properties.h" 29 30 namespace fst { 31 32 template <class Arc, class Queue, class ArcFilter> 33 struct ShortestPathOptions 34 : public ShortestDistanceOptions<Arc, Queue, ArcFilter> { 35 typedef typename Arc::StateId StateId; 36 37 size_t nshortest; // return n-shortest paths 38 bool unique; // only return paths with distinct input strings 39 bool has_distance; // distance vector already contains the 40 // shortest distance from the initial state 41 42 ShortestPathOptions(Queue *q, ArcFilter filt, size_t n = 1, bool u = false, 43 bool hasdist = false, float d = kDelta) 44 : ShortestDistanceOptions<Arc, Queue, ArcFilter>(q, filt, kNoStateId, d), 45 nshortest(n), unique(u), has_distance(hasdist) {} 46 }; 47 48 49 // Shortest-path algorithm: normally not called directly; prefer 50 // 'ShortestPath' below with n=1. 'ofst' contains the shortest path in 51 // 'ifst'. 'distance' returns the shortest distances from the source 52 // state to each state in 'ifst'. 'opts' is used to specify options 53 // such as the queue discipline, the arc filter and delta. 54 // 55 // The shortest path is the lowest weight path w.r.t. the natural 56 // semiring order. 57 // 58 // The weights need to be right distributive and have the path (kPath) 59 // property. 60 template<class Arc, class Queue, class ArcFilter> 61 void SingleShortestPath(const Fst<Arc> &ifst, 62 MutableFst<Arc> *ofst, 63 vector<typename Arc::Weight> *distance, 64 ShortestPathOptions<Arc, Queue, ArcFilter> &opts) { 65 typedef typename Arc::StateId StateId; 66 typedef typename Arc::Weight Weight; 67 68 ofst->DeleteStates(); 69 ofst->SetInputSymbols(ifst.InputSymbols()); 70 ofst->SetOutputSymbols(ifst.OutputSymbols()); 71 72 if (ifst.Start() == kNoStateId) 73 return; 74 75 vector<Weight> rdistance; 76 vector<bool> enqueued; 77 vector<StateId> parent; 78 vector<Arc> arc_parent; 79 80 Queue *state_queue = opts.state_queue; 81 StateId source = opts.source == kNoStateId ? ifst.Start() : opts.source; 82 Weight f_distance = Weight::Zero(); 83 StateId f_parent = kNoStateId; 84 85 distance->clear(); 86 state_queue->Clear(); 87 if (opts.nshortest != 1) 88 LOG(FATAL) << "SingleShortestPath: for nshortest > 1, use ShortestPath" 89 << " instead"; 90 if ((Weight::Properties() & (kPath | kRightSemiring)) 91 != (kPath | kRightSemiring)) 92 LOG(FATAL) << "SingleShortestPath: Weight needs to have the path" 93 << " property and be right distributive: " << Weight::Type(); 94 95 while (distance->size() < source) { 96 distance->push_back(Weight::Zero()); 97 enqueued.push_back(false); 98 parent.push_back(kNoStateId); 99 arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId)); 100 } 101 distance->push_back(Weight::One()); 102 parent.push_back(kNoStateId); 103 arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId)); 104 state_queue->Enqueue(source); 105 enqueued.push_back(true); 106 107 while (!state_queue->Empty()) { 108 StateId s = state_queue->Head(); 109 state_queue->Dequeue(); 110 enqueued[s] = false; 111 Weight sd = (*distance)[s]; 112 for (ArcIterator< Fst<Arc> > aiter(ifst, s); 113 !aiter.Done(); 114 aiter.Next()) { 115 const Arc &arc = aiter.Value(); 116 while (distance->size() <= arc.nextstate) { 117 distance->push_back(Weight::Zero()); 118 enqueued.push_back(false); 119 parent.push_back(kNoStateId); 120 arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), 121 kNoStateId)); 122 } 123 Weight &nd = (*distance)[arc.nextstate]; 124 Weight w = Times(sd, arc.weight); 125 if (nd != Plus(nd, w)) { 126 nd = Plus(nd, w); 127 parent[arc.nextstate] = s; 128 arc_parent[arc.nextstate] = arc; 129 if (!enqueued[arc.nextstate]) { 130 state_queue->Enqueue(arc.nextstate); 131 enqueued[arc.nextstate] = true; 132 } else { 133 state_queue->Update(arc.nextstate); 134 } 135 } 136 } 137 if (ifst.Final(s) != Weight::Zero()) { 138 Weight w = Times(sd, ifst.Final(s)); 139 if (f_distance != Plus(f_distance, w)) { 140 f_distance = Plus(f_distance, w); 141 f_parent = s; 142 } 143 } 144 } 145 (*distance)[source] = Weight::One(); 146 parent[source] = kNoStateId; 147 148 StateId s_p = kNoStateId, d_p = kNoStateId; 149 for (StateId s = f_parent, d = kNoStateId; 150 s != kNoStateId; 151 d = s, s = parent[s]) { 152 enqueued[s] = true; 153 d_p = s_p; 154 s_p = ofst->AddState(); 155 if (d == kNoStateId) { 156 ofst->SetFinal(s_p, ifst.Final(f_parent)); 157 } else { 158 arc_parent[d].nextstate = d_p; 159 ofst->AddArc(s_p, arc_parent[d]); 160 } 161 } 162 ofst->SetStart(s_p); 163 } 164 165 166 template <class S, class W> 167 class ShortestPathCompare { 168 public: 169 typedef S StateId; 170 typedef W Weight; 171 typedef pair<StateId, Weight> Pair; 172 173 ShortestPathCompare(const vector<Pair>& pairs, 174 const vector<Weight>& distance, 175 StateId sfinal, float d) 176 : pairs_(pairs), distance_(distance), superfinal_(sfinal), delta_(d) {} 177 178 bool operator()(const StateId x, const StateId y) const { 179 const Pair &px = pairs_[x]; 180 const Pair &py = pairs_[y]; 181 Weight wx = Times(distance_[px.first], px.second); 182 Weight wy = Times(distance_[py.first], py.second); 183 // Penalize complete paths to ensure correct results with inexact weights. 184 // This forms a strict weak order so long as ApproxEqual(a, b) => 185 // ApproxEqual(a, c) for all c s.t. less_(a, c) && less_(c, b). 186 if (px.first == superfinal_ && py.first != superfinal_) { 187 return less_(wy, wx) || ApproxEqual(wx, wy, delta_); 188 } else if (py.first == superfinal_ && px.first != superfinal_) { 189 return less_(wy, wx) && !ApproxEqual(wx, wy, delta_); 190 } else { 191 return less_(wy, wx); 192 } 193 } 194 195 private: 196 const vector<Pair> &pairs_; 197 const vector<Weight> &distance_; 198 StateId superfinal_; 199 float delta_; 200 NaturalLess<Weight> less_; 201 }; 202 203 204 // N-Shortest-path algorithm: this version allow fine control 205 // via the otpions argument. See below for a simpler interface. 206 // 207 // 'ofst' contains the n-shortest paths in 'ifst'. 'distance' returns 208 // the shortest distances from the source state to each state in 209 // 'ifst'. 'opts' is used to specify options such as the number of 210 // paths to return, whether they need to have distinct input 211 // strings, the queue discipline, the arc filter and the convergence 212 // delta. 213 // 214 // The n-shortest paths are the n-lowest weight paths w.r.t. the 215 // natural semiring order. The single path that can be 216 // read from the ith of at most n transitions leaving the initial 217 // state of 'ofst' is the ith shortest path. 218 219 // The weights need to be right distributive and have the path (kPath) 220 // property. They need to be left distributive as well for nshortest 221 // > 1. 222 // 223 // The algorithm is from Mohri and Riley, "An Efficient Algorithm for 224 // the n-best-strings problem", ICSLP 2002. The algorithm relies on 225 // the shortest-distance algorithm. There are some issues with the 226 // pseudo-code as written in the paper (viz., line 11). 227 template<class Arc, class Queue, class ArcFilter> 228 void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst, 229 vector<typename Arc::Weight> *distance, 230 ShortestPathOptions<Arc, Queue, ArcFilter> &opts) { 231 typedef typename Arc::StateId StateId; 232 typedef typename Arc::Weight Weight; 233 typedef pair<StateId, Weight> Pair; 234 typedef ReverseArc<Arc> ReverseArc; 235 typedef typename ReverseArc::Weight ReverseWeight; 236 237 size_t n = opts.nshortest; 238 239 if (n == 1) { 240 SingleShortestPath(ifst, ofst, distance, opts); 241 return; 242 } 243 ofst->DeleteStates(); 244 ofst->SetInputSymbols(ifst.InputSymbols()); 245 ofst->SetOutputSymbols(ifst.OutputSymbols()); 246 if (n <= 0) return; 247 if ((Weight::Properties() & (kPath | kSemiring)) != (kPath | kSemiring)) 248 LOG(FATAL) << "ShortestPath: n-shortest: Weight needs to have the " 249 << "path property and be distributive: " 250 << Weight::Type(); 251 if (opts.unique) 252 LOG(FATAL) << "ShortestPath: n-shortest-string algorithm not " 253 << "currently implemented"; 254 255 // Algorithm works on the reverse of 'fst' : 'rfst' 'distance' is 256 // the distance to the final state in 'rfst' 'ofst' is built as the 257 // reverse of the tree of n-shortest path in 'rfst'. 258 259 if (!opts.has_distance) 260 ShortestDistance(ifst, distance, opts); 261 VectorFst<ReverseArc> rfst; 262 Reverse(ifst, &rfst); 263 distance->insert(distance->begin(), Weight::One()); 264 while (distance->size() < rfst.NumStates()) 265 distance->push_back(Weight::Zero()); 266 267 268 // Each state in 'ofst' corresponds to a path with weight w from the 269 // initial state of 'rfst' to a state s in 'rfst', that can be 270 // characterized by a pair (s,w). The vector 'pairs' maps each 271 // state in 'ofst' to the corresponding pair maps states in OFST to 272 // the corresponding pair (s,w). 273 vector<Pair> pairs; 274 // 'r[s]', 's' state in 'fst', is the number of states in 'ofst' 275 // which corresponding pair contains 's' ,i.e. , it is number of 276 // paths computed so far to 's'. 277 StateId superfinal = distance->size(); // superfinal must be handled 278 distance->push_back(Weight::One()); // differently when unique=true 279 ShortestPathCompare<StateId, Weight> 280 compare(pairs, *distance, superfinal, opts.delta); 281 vector<StateId> heap; 282 vector<int> r; 283 while (r.size() < distance->size()) 284 r.push_back(0); 285 ofst->SetStart(ofst->AddState()); 286 StateId final = ofst->AddState(); 287 ofst->SetFinal(final, Weight::One()); 288 while (pairs.size() <= final) 289 pairs.push_back(Pair(kNoStateId, Weight::Zero())); 290 pairs[final] = Pair(rfst.Start(), Weight::One()); 291 heap.push_back(final); 292 293 while (!heap.empty()) { 294 pop_heap(heap.begin(), heap.end(), compare); 295 StateId state = heap.back(); 296 Pair p = pairs[state]; 297 heap.pop_back(); 298 299 ++r[p.first]; 300 if (p.first == superfinal) 301 ofst->AddArc(ofst->Start(), Arc(0, 0, Weight::One(), state)); 302 if ((p.first == superfinal) && (r[p.first] == n)) break; 303 if (r[p.first] > n) continue; 304 if (p.first == superfinal) 305 continue; 306 307 for (ArcIterator< Fst<ReverseArc> > aiter(rfst, p.first); 308 !aiter.Done(); 309 aiter.Next()) { 310 const ReverseArc &rarc = aiter.Value(); 311 Arc arc(rarc.ilabel, rarc.olabel, rarc.weight.Reverse(), rarc.nextstate); 312 Weight w = Times(p.second, arc.weight); 313 StateId next = ofst->AddState(); 314 pairs.push_back(Pair(arc.nextstate, w)); 315 arc.nextstate = state; 316 ofst->AddArc(next, arc); 317 heap.push_back(next); 318 push_heap(heap.begin(), heap.end(), compare); 319 } 320 321 Weight finalw = rfst.Final(p.first).Reverse(); 322 if (finalw != Weight::Zero()) { 323 Weight w = Times(p.second, finalw); 324 StateId next = ofst->AddState(); 325 pairs.push_back(Pair(superfinal, w)); 326 ofst->AddArc(next, Arc(0, 0, finalw, state)); 327 heap.push_back(next); 328 push_heap(heap.begin(), heap.end(), compare); 329 } 330 } 331 Connect(ofst); 332 distance->erase(distance->begin()); 333 distance->pop_back(); 334 } 335 336 // Shortest-path algorithm: simplified interface. See above for a 337 // version that allows finer control. 338 339 // 'ofst' contains the 'n'-shortest paths in 'ifst'. The queue 340 // discipline is automatically selected. When 'unique' == true, only 341 // paths with distinct input labels are returned. 342 // 343 // The n-shortest paths are the n-lowest weight paths w.r.t. the 344 // natural semiring order. The single path that can be read from the 345 // ith of at most n transitions leaving the initial state of 'ofst' is 346 // the ith best path. 347 // 348 // The weights need to be right distributive and have the path 349 // (kPath) property. 350 template<class Arc> 351 void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst, 352 size_t n = 1, bool unique = false) { 353 vector<typename Arc::Weight> distance; 354 AnyArcFilter<Arc> arc_filter; 355 AutoQueue<typename Arc::StateId> state_queue(ifst, &distance, arc_filter); 356 ShortestPathOptions< Arc, AutoQueue<typename Arc::StateId>, 357 AnyArcFilter<Arc> > opts(&state_queue, arc_filter, n, unique); 358 ShortestPath(ifst, ofst, &distance, opts); 359 } 360 361 } // namespace fst 362 363 #endif // FST_LIB_SHORTEST_PATH_H__ 364