1 // shortest-path.h 2 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Copyright 2005-2010 Google, Inc. 16 // Author: allauzen (at) google.com (Cyril Allauzen) 17 // 18 // \file 19 // Functions to find shortest paths in an FST. 20 21 #ifndef FST_LIB_SHORTEST_PATH_H__ 22 #define FST_LIB_SHORTEST_PATH_H__ 23 24 #include <functional> 25 #include <utility> 26 using std::pair; using std::make_pair; 27 #include <vector> 28 using std::vector; 29 30 #include <fst/cache.h> 31 #include <fst/determinize.h> 32 #include <fst/queue.h> 33 #include <fst/shortest-distance.h> 34 #include <fst/test-properties.h> 35 36 37 namespace fst { 38 39 template <class Arc, class Queue, class ArcFilter> 40 struct ShortestPathOptions 41 : public ShortestDistanceOptions<Arc, Queue, ArcFilter> { 42 typedef typename Arc::StateId StateId; 43 typedef typename Arc::Weight Weight; 44 size_t nshortest; // return n-shortest paths 45 bool unique; // only return paths with distinct input strings 46 bool has_distance; // distance vector already contains the 47 // shortest distance from the initial state 48 bool first_path; // Single shortest path stops after finding the first 49 // path to a final state. That path is the shortest path 50 // only when using the ShortestFirstQueue and 51 // only when all the weights in the FST are between 52 // One() and Zero() according to NaturalLess. 53 Weight weight_threshold; // pruning weight threshold. 54 StateId state_threshold; // pruning state threshold. 55 56 ShortestPathOptions(Queue *q, ArcFilter filt, size_t n = 1, bool u = false, 57 bool hasdist = false, float d = kDelta, 58 bool fp = false, Weight w = Weight::Zero(), 59 StateId s = kNoStateId) 60 : ShortestDistanceOptions<Arc, Queue, ArcFilter>(q, filt, kNoStateId, d), 61 nshortest(n), unique(u), has_distance(hasdist), first_path(fp), 62 weight_threshold(w), state_threshold(s) {} 63 }; 64 65 66 // Shortest-path algorithm: normally not called directly; prefer 67 // 'ShortestPath' below with n=1. 'ofst' contains the shortest path in 68 // 'ifst'. 'distance' returns the shortest distances from the source 69 // state to each state in 'ifst'. 'opts' is used to specify options 70 // such as the queue discipline, the arc filter and delta. 71 // 72 // The shortest path is the lowest weight path w.r.t. the natural 73 // semiring order. 74 // 75 // The weights need to be right distributive and have the path (kPath) 76 // property. 77 template<class Arc, class Queue, class ArcFilter> 78 void SingleShortestPath(const Fst<Arc> &ifst, 79 MutableFst<Arc> *ofst, 80 vector<typename Arc::Weight> *distance, 81 ShortestPathOptions<Arc, Queue, ArcFilter> &opts) { 82 typedef typename Arc::StateId StateId; 83 typedef typename Arc::Weight Weight; 84 85 ofst->DeleteStates(); 86 ofst->SetInputSymbols(ifst.InputSymbols()); 87 ofst->SetOutputSymbols(ifst.OutputSymbols()); 88 89 if (ifst.Start() == kNoStateId) { 90 if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError); 91 return; 92 } 93 94 vector<bool> enqueued; 95 vector<StateId> parent; 96 vector<Arc> arc_parent; 97 98 Queue *state_queue = opts.state_queue; 99 StateId source = opts.source == kNoStateId ? ifst.Start() : opts.source; 100 Weight f_distance = Weight::Zero(); 101 StateId f_parent = kNoStateId; 102 103 distance->clear(); 104 state_queue->Clear(); 105 if (opts.nshortest != 1) { 106 FSTERROR() << "SingleShortestPath: for nshortest > 1, use ShortestPath" 107 << " instead"; 108 ofst->SetProperties(kError, kError); 109 return; 110 } 111 if (opts.weight_threshold != Weight::Zero() || 112 opts.state_threshold != kNoStateId) { 113 FSTERROR() << 114 "SingleShortestPath: weight and state thresholds not applicable"; 115 ofst->SetProperties(kError, kError); 116 return; 117 } 118 if ((Weight::Properties() & (kPath | kRightSemiring)) 119 != (kPath | kRightSemiring)) { 120 FSTERROR() << "SingleShortestPath: Weight needs to have the path" 121 << " property and be right distributive: " << Weight::Type(); 122 ofst->SetProperties(kError, kError); 123 return; 124 } 125 while (distance->size() < source) { 126 distance->push_back(Weight::Zero()); 127 enqueued.push_back(false); 128 parent.push_back(kNoStateId); 129 arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId)); 130 } 131 distance->push_back(Weight::One()); 132 parent.push_back(kNoStateId); 133 arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId)); 134 state_queue->Enqueue(source); 135 enqueued.push_back(true); 136 137 while (!state_queue->Empty()) { 138 StateId s = state_queue->Head(); 139 state_queue->Dequeue(); 140 enqueued[s] = false; 141 Weight sd = (*distance)[s]; 142 if (ifst.Final(s) != Weight::Zero()) { 143 Weight w = Times(sd, ifst.Final(s)); 144 if (f_distance != Plus(f_distance, w)) { 145 f_distance = Plus(f_distance, w); 146 f_parent = s; 147 } 148 if (!f_distance.Member()) { 149 ofst->SetProperties(kError, kError); 150 return; 151 } 152 if (opts.first_path) 153 break; 154 } 155 for (ArcIterator< Fst<Arc> > aiter(ifst, s); 156 !aiter.Done(); 157 aiter.Next()) { 158 const Arc &arc = aiter.Value(); 159 while (distance->size() <= arc.nextstate) { 160 distance->push_back(Weight::Zero()); 161 enqueued.push_back(false); 162 parent.push_back(kNoStateId); 163 arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), 164 kNoStateId)); 165 } 166 Weight &nd = (*distance)[arc.nextstate]; 167 Weight w = Times(sd, arc.weight); 168 if (nd != Plus(nd, w)) { 169 nd = Plus(nd, w); 170 if (!nd.Member()) { 171 ofst->SetProperties(kError, kError); 172 return; 173 } 174 parent[arc.nextstate] = s; 175 arc_parent[arc.nextstate] = arc; 176 if (!enqueued[arc.nextstate]) { 177 state_queue->Enqueue(arc.nextstate); 178 enqueued[arc.nextstate] = true; 179 } else { 180 state_queue->Update(arc.nextstate); 181 } 182 } 183 } 184 } 185 186 StateId s_p = kNoStateId, d_p = kNoStateId; 187 for (StateId s = f_parent, d = kNoStateId; 188 s != kNoStateId; 189 d = s, s = parent[s]) { 190 d_p = s_p; 191 s_p = ofst->AddState(); 192 if (d == kNoStateId) { 193 ofst->SetFinal(s_p, ifst.Final(f_parent)); 194 } else { 195 arc_parent[d].nextstate = d_p; 196 ofst->AddArc(s_p, arc_parent[d]); 197 } 198 } 199 ofst->SetStart(s_p); 200 if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError); 201 ofst->SetProperties( 202 ShortestPathProperties(ofst->Properties(kFstProperties, false)), 203 kFstProperties); 204 } 205 206 207 template <class S, class W> 208 class ShortestPathCompare { 209 public: 210 typedef S StateId; 211 typedef W Weight; 212 typedef pair<StateId, Weight> Pair; 213 214 ShortestPathCompare(const vector<Pair>& pairs, 215 const vector<Weight>& distance, 216 StateId sfinal, float d) 217 : pairs_(pairs), distance_(distance), superfinal_(sfinal), delta_(d) {} 218 219 bool operator()(const StateId x, const StateId y) const { 220 const Pair &px = pairs_[x]; 221 const Pair &py = pairs_[y]; 222 Weight dx = px.first == superfinal_ ? Weight::One() : 223 px.first < distance_.size() ? distance_[px.first] : Weight::Zero(); 224 Weight dy = py.first == superfinal_ ? Weight::One() : 225 py.first < distance_.size() ? distance_[py.first] : Weight::Zero(); 226 Weight wx = Times(dx, px.second); 227 Weight wy = Times(dy, py.second); 228 // Penalize complete paths to ensure correct results with inexact weights. 229 // This forms a strict weak order so long as ApproxEqual(a, b) => 230 // ApproxEqual(a, c) for all c s.t. less_(a, c) && less_(c, b). 231 if (px.first == superfinal_ && py.first != superfinal_) { 232 return less_(wy, wx) || ApproxEqual(wx, wy, delta_); 233 } else if (py.first == superfinal_ && px.first != superfinal_) { 234 return less_(wy, wx) && !ApproxEqual(wx, wy, delta_); 235 } else { 236 return less_(wy, wx); 237 } 238 } 239 240 private: 241 const vector<Pair> &pairs_; 242 const vector<Weight> &distance_; 243 StateId superfinal_; 244 float delta_; 245 NaturalLess<Weight> less_; 246 }; 247 248 249 // N-Shortest-path algorithm: implements the core n-shortest path 250 // algorithm. The output is built REVERSED. See below for versions with 251 // more options and not reversed. 252 // 253 // 'ofst' contains the REVERSE of 'n'-shortest paths in 'ifst'. 254 // 'distance' must contain the shortest distance from each state to a final 255 // state in 'ifst'. 'delta' is the convergence delta. 256 // 257 // The n-shortest paths are the n-lowest weight paths w.r.t. the 258 // natural semiring order. The single path that can be read from the 259 // ith of at most n transitions leaving the initial state of 'ofst' is 260 // the ith shortest path. Disregarding the initial state and initial 261 // transitions, the n-shortest paths, in fact, form a tree rooted at 262 // the single final state. 263 // 264 // The weights need to be left and right distributive (kSemiring) and 265 // have the path (kPath) property. 266 // 267 // The algorithm is from Mohri and Riley, "An Efficient Algorithm for 268 // the n-best-strings problem", ICSLP 2002. The algorithm relies on 269 // the shortest-distance algorithm. There are some issues with the 270 // pseudo-code as written in the paper (viz., line 11). 271 // 272 // IMPLEMENTATION NOTE: The input fst 'ifst' can be a delayed fst and 273 // and at any state in its expansion the values of distance vector need only 274 // be defined at that time for the states that are known to exist. 275 template<class Arc, class RevArc> 276 void NShortestPath(const Fst<RevArc> &ifst, 277 MutableFst<Arc> *ofst, 278 const vector<typename Arc::Weight> &distance, 279 size_t n, 280 float delta = kDelta, 281 typename Arc::Weight weight_threshold = Arc::Weight::Zero(), 282 typename Arc::StateId state_threshold = kNoStateId) { 283 typedef typename Arc::StateId StateId; 284 typedef typename Arc::Weight Weight; 285 typedef pair<StateId, Weight> Pair; 286 typedef typename RevArc::Weight RevWeight; 287 288 if (n <= 0) return; 289 if ((Weight::Properties() & (kPath | kSemiring)) != (kPath | kSemiring)) { 290 FSTERROR() << "NShortestPath: Weight needs to have the " 291 << "path property and be distributive: " 292 << Weight::Type(); 293 ofst->SetProperties(kError, kError); 294 return; 295 } 296 ofst->DeleteStates(); 297 ofst->SetInputSymbols(ifst.InputSymbols()); 298 ofst->SetOutputSymbols(ifst.OutputSymbols()); 299 // Each state in 'ofst' corresponds to a path with weight w from the 300 // initial state of 'ifst' to a state s in 'ifst', that can be 301 // characterized by a pair (s,w). The vector 'pairs' maps each 302 // state in 'ofst' to the corresponding pair maps states in OFST to 303 // the corresponding pair (s,w). 304 vector<Pair> pairs; 305 // The supefinal state is denoted by -1, 'compare' knows that the 306 // distance from 'superfinal' to the final state is 'Weight::One()', 307 // hence 'distance[superfinal]' is not needed. 308 StateId superfinal = -1; 309 ShortestPathCompare<StateId, Weight> 310 compare(pairs, distance, superfinal, delta); 311 vector<StateId> heap; 312 // 'r[s + 1]', 's' state in 'fst', is the number of states in 'ofst' 313 // which corresponding pair contains 's' ,i.e. , it is number of 314 // paths computed so far to 's'. Valid for 's == -1' (superfinal). 315 vector<int> r; 316 NaturalLess<Weight> less; 317 if (ifst.Start() == kNoStateId || 318 distance.size() <= ifst.Start() || 319 distance[ifst.Start()] == Weight::Zero() || 320 less(weight_threshold, Weight::One()) || 321 state_threshold == 0) { 322 if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError); 323 return; 324 } 325 ofst->SetStart(ofst->AddState()); 326 StateId final = ofst->AddState(); 327 ofst->SetFinal(final, Weight::One()); 328 while (pairs.size() <= final) 329 pairs.push_back(Pair(kNoStateId, Weight::Zero())); 330 pairs[final] = Pair(ifst.Start(), Weight::One()); 331 heap.push_back(final); 332 Weight limit = Times(distance[ifst.Start()], weight_threshold); 333 334 while (!heap.empty()) { 335 pop_heap(heap.begin(), heap.end(), compare); 336 StateId state = heap.back(); 337 Pair p = pairs[state]; 338 heap.pop_back(); 339 Weight d = p.first == superfinal ? Weight::One() : 340 p.first < distance.size() ? distance[p.first] : Weight::Zero(); 341 342 if (less(limit, Times(d, p.second)) || 343 (state_threshold != kNoStateId && 344 ofst->NumStates() >= state_threshold)) 345 continue; 346 347 while (r.size() <= p.first + 1) r.push_back(0); 348 ++r[p.first + 1]; 349 if (p.first == superfinal) 350 ofst->AddArc(ofst->Start(), Arc(0, 0, Weight::One(), state)); 351 if ((p.first == superfinal) && (r[p.first + 1] == n)) break; 352 if (r[p.first + 1] > n) continue; 353 if (p.first == superfinal) continue; 354 355 for (ArcIterator< Fst<RevArc> > aiter(ifst, p.first); 356 !aiter.Done(); 357 aiter.Next()) { 358 const RevArc &rarc = aiter.Value(); 359 Arc arc(rarc.ilabel, rarc.olabel, rarc.weight.Reverse(), rarc.nextstate); 360 Weight w = Times(p.second, arc.weight); 361 StateId next = ofst->AddState(); 362 pairs.push_back(Pair(arc.nextstate, w)); 363 arc.nextstate = state; 364 ofst->AddArc(next, arc); 365 heap.push_back(next); 366 push_heap(heap.begin(), heap.end(), compare); 367 } 368 369 Weight finalw = ifst.Final(p.first).Reverse(); 370 if (finalw != Weight::Zero()) { 371 Weight w = Times(p.second, finalw); 372 StateId next = ofst->AddState(); 373 pairs.push_back(Pair(superfinal, w)); 374 ofst->AddArc(next, Arc(0, 0, finalw, state)); 375 heap.push_back(next); 376 push_heap(heap.begin(), heap.end(), compare); 377 } 378 } 379 Connect(ofst); 380 if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError); 381 ofst->SetProperties( 382 ShortestPathProperties(ofst->Properties(kFstProperties, false)), 383 kFstProperties); 384 } 385 386 387 // N-Shortest-path algorithm: this version allow fine control 388 // via the options argument. See below for a simpler interface. 389 // 390 // 'ofst' contains the n-shortest paths in 'ifst'. 'distance' returns 391 // the shortest distances from the source state to each state in 392 // 'ifst'. 'opts' is used to specify options such as the number of 393 // paths to return, whether they need to have distinct input 394 // strings, the queue discipline, the arc filter and the convergence 395 // delta. 396 // 397 // The n-shortest paths are the n-lowest weight paths w.r.t. the 398 // natural semiring order. The single path that can be read from the 399 // ith of at most n transitions leaving the initial state of 'ofst' is 400 // the ith shortest path. Disregarding the initial state and initial 401 // transitions, The n-shortest paths, in fact, form a tree rooted at 402 // the single final state. 403 404 // The weights need to be right distributive and have the path (kPath) 405 // property. They need to be left distributive as well for nshortest 406 // > 1. 407 // 408 // The algorithm is from Mohri and Riley, "An Efficient Algorithm for 409 // the n-best-strings problem", ICSLP 2002. The algorithm relies on 410 // the shortest-distance algorithm. There are some issues with the 411 // pseudo-code as written in the paper (viz., line 11). 412 template<class Arc, class Queue, class ArcFilter> 413 void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst, 414 vector<typename Arc::Weight> *distance, 415 ShortestPathOptions<Arc, Queue, ArcFilter> &opts) { 416 typedef typename Arc::StateId StateId; 417 typedef typename Arc::Weight Weight; 418 typedef ReverseArc<Arc> ReverseArc; 419 420 size_t n = opts.nshortest; 421 if (n == 1) { 422 SingleShortestPath(ifst, ofst, distance, opts); 423 return; 424 } 425 if (n <= 0) return; 426 if ((Weight::Properties() & (kPath | kSemiring)) != (kPath | kSemiring)) { 427 FSTERROR() << "ShortestPath: n-shortest: Weight needs to have the " 428 << "path property and be distributive: " 429 << Weight::Type(); 430 ofst->SetProperties(kError, kError); 431 return; 432 } 433 if (!opts.has_distance) { 434 ShortestDistance(ifst, distance, opts); 435 if (distance->size() == 1 && !(*distance)[0].Member()) { 436 ofst->SetProperties(kError, kError); 437 return; 438 } 439 } 440 // Algorithm works on the reverse of 'fst' : 'rfst', 'distance' is 441 // the distance to the final state in 'rfst', 'ofst' is built as the 442 // reverse of the tree of n-shortest path in 'rfst'. 443 VectorFst<ReverseArc> rfst; 444 Reverse(ifst, &rfst); 445 Weight d = Weight::Zero(); 446 for (ArcIterator< VectorFst<ReverseArc> > aiter(rfst, 0); 447 !aiter.Done(); aiter.Next()) { 448 const ReverseArc &arc = aiter.Value(); 449 StateId s = arc.nextstate - 1; 450 if (s < distance->size()) 451 d = Plus(d, Times(arc.weight.Reverse(), (*distance)[s])); 452 } 453 distance->insert(distance->begin(), d); 454 455 if (!opts.unique) { 456 NShortestPath(rfst, ofst, *distance, n, opts.delta, 457 opts.weight_threshold, opts.state_threshold); 458 } else { 459 vector<Weight> ddistance; 460 DeterminizeFstOptions<ReverseArc> dopts(opts.delta); 461 DeterminizeFst<ReverseArc> dfst(rfst, distance, &ddistance, dopts); 462 NShortestPath(dfst, ofst, ddistance, n, opts.delta, 463 opts.weight_threshold, opts.state_threshold); 464 } 465 distance->erase(distance->begin()); 466 } 467 468 469 // Shortest-path algorithm: simplified interface. See above for a 470 // version that allows finer control. 471 // 472 // 'ofst' contains the 'n'-shortest paths in 'ifst'. The queue 473 // discipline is automatically selected. When 'unique' == true, only 474 // paths with distinct input labels are returned. 475 // 476 // The n-shortest paths are the n-lowest weight paths w.r.t. the 477 // natural semiring order. The single path that can be read from the 478 // ith of at most n transitions leaving the initial state of 'ofst' is 479 // the ith best path. 480 // 481 // The weights need to be right distributive and have the path 482 // (kPath) property. 483 template<class Arc> 484 void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst, 485 size_t n = 1, bool unique = false, 486 bool first_path = false, 487 typename Arc::Weight weight_threshold = Arc::Weight::Zero(), 488 typename Arc::StateId state_threshold = kNoStateId) { 489 vector<typename Arc::Weight> distance; 490 AnyArcFilter<Arc> arc_filter; 491 AutoQueue<typename Arc::StateId> state_queue(ifst, &distance, arc_filter); 492 ShortestPathOptions< Arc, AutoQueue<typename Arc::StateId>, 493 AnyArcFilter<Arc> > opts(&state_queue, arc_filter, n, unique, false, 494 kDelta, first_path, weight_threshold, 495 state_threshold); 496 ShortestPath(ifst, ofst, &distance, opts); 497 } 498 499 } // namespace fst 500 501 #endif // FST_LIB_SHORTEST_PATH_H__ 502