Home | History | Annotate | Download | only in fst
      1 // shortest-path.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: allauzen (at) google.com (Cyril Allauzen)
     17 //
     18 // \file
     19 // Functions to find shortest paths in an FST.
     20 
     21 #ifndef FST_LIB_SHORTEST_PATH_H__
     22 #define FST_LIB_SHORTEST_PATH_H__
     23 
     24 #include <functional>
     25 #include <utility>
     26 using std::pair; using std::make_pair;
     27 #include <vector>
     28 using std::vector;
     29 
     30 #include <fst/cache.h>
     31 #include <fst/determinize.h>
     32 #include <fst/queue.h>
     33 #include <fst/shortest-distance.h>
     34 #include <fst/test-properties.h>
     35 
     36 
     37 namespace fst {
     38 
     39 template <class Arc, class Queue, class ArcFilter>
     40 struct ShortestPathOptions
     41     : public ShortestDistanceOptions<Arc, Queue, ArcFilter> {
     42   typedef typename Arc::StateId StateId;
     43   typedef typename Arc::Weight Weight;
     44   size_t nshortest;   // return n-shortest paths
     45   bool unique;        // only return paths with distinct input strings
     46   bool has_distance;  // distance vector already contains the
     47                       // shortest distance from the initial state
     48   bool first_path;    // Single shortest path stops after finding the first
     49                       // path to a final state. That path is the shortest path
     50                       // only when using the ShortestFirstQueue and
     51                       // only when all the weights in the FST are between
     52                       // One() and Zero() according to NaturalLess.
     53   Weight weight_threshold;   // pruning weight threshold.
     54   StateId state_threshold;   // pruning state threshold.
     55 
     56   ShortestPathOptions(Queue *q, ArcFilter filt, size_t n = 1, bool u = false,
     57                       bool hasdist = false, float d = kDelta,
     58                       bool fp = false, Weight w = Weight::Zero(),
     59                       StateId s = kNoStateId)
     60       : ShortestDistanceOptions<Arc, Queue, ArcFilter>(q, filt, kNoStateId, d),
     61         nshortest(n), unique(u), has_distance(hasdist), first_path(fp),
     62         weight_threshold(w), state_threshold(s) {}
     63 };
     64 
     65 
     66 // Shortest-path algorithm: normally not called directly; prefer
     67 // 'ShortestPath' below with n=1. 'ofst' contains the shortest path in
     68 // 'ifst'. 'distance' returns the shortest distances from the source
     69 // state to each state in 'ifst'. 'opts' is used to specify options
     70 // such as the queue discipline, the arc filter and delta.
     71 //
     72 // The shortest path is the lowest weight path w.r.t. the natural
     73 // semiring order.
     74 //
     75 // The weights need to be right distributive and have the path (kPath)
     76 // property.
     77 template<class Arc, class Queue, class ArcFilter>
     78 void SingleShortestPath(const Fst<Arc> &ifst,
     79                   MutableFst<Arc> *ofst,
     80                   vector<typename Arc::Weight> *distance,
     81                   ShortestPathOptions<Arc, Queue, ArcFilter> &opts) {
     82   typedef typename Arc::StateId StateId;
     83   typedef typename Arc::Weight Weight;
     84 
     85   ofst->DeleteStates();
     86   ofst->SetInputSymbols(ifst.InputSymbols());
     87   ofst->SetOutputSymbols(ifst.OutputSymbols());
     88 
     89   if (ifst.Start() == kNoStateId) {
     90     if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
     91     return;
     92   }
     93 
     94   vector<bool> enqueued;
     95   vector<StateId> parent;
     96   vector<Arc> arc_parent;
     97 
     98   Queue *state_queue = opts.state_queue;
     99   StateId source = opts.source == kNoStateId ? ifst.Start() : opts.source;
    100   Weight f_distance = Weight::Zero();
    101   StateId f_parent = kNoStateId;
    102 
    103   distance->clear();
    104   state_queue->Clear();
    105   if (opts.nshortest != 1) {
    106     FSTERROR() << "SingleShortestPath: for nshortest > 1, use ShortestPath"
    107                << " instead";
    108     ofst->SetProperties(kError, kError);
    109     return;
    110   }
    111   if (opts.weight_threshold != Weight::Zero() ||
    112       opts.state_threshold != kNoStateId) {
    113     FSTERROR() <<
    114         "SingleShortestPath: weight and state thresholds not applicable";
    115     ofst->SetProperties(kError, kError);
    116     return;
    117   }
    118   if ((Weight::Properties() & (kPath | kRightSemiring))
    119       != (kPath | kRightSemiring)) {
    120     FSTERROR() << "SingleShortestPath: Weight needs to have the path"
    121                << " property and be right distributive: " << Weight::Type();
    122     ofst->SetProperties(kError, kError);
    123     return;
    124   }
    125   while (distance->size() < source) {
    126     distance->push_back(Weight::Zero());
    127     enqueued.push_back(false);
    128     parent.push_back(kNoStateId);
    129     arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId));
    130   }
    131   distance->push_back(Weight::One());
    132   parent.push_back(kNoStateId);
    133   arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId));
    134   state_queue->Enqueue(source);
    135   enqueued.push_back(true);
    136 
    137   while (!state_queue->Empty()) {
    138     StateId s = state_queue->Head();
    139     state_queue->Dequeue();
    140     enqueued[s] = false;
    141     Weight sd = (*distance)[s];
    142     if (ifst.Final(s) != Weight::Zero()) {
    143       Weight w = Times(sd, ifst.Final(s));
    144       if (f_distance != Plus(f_distance, w)) {
    145         f_distance = Plus(f_distance, w);
    146         f_parent = s;
    147       }
    148       if (!f_distance.Member()) {
    149         ofst->SetProperties(kError, kError);
    150         return;
    151       }
    152       if (opts.first_path)
    153         break;
    154     }
    155     for (ArcIterator< Fst<Arc> > aiter(ifst, s);
    156          !aiter.Done();
    157          aiter.Next()) {
    158       const Arc &arc = aiter.Value();
    159       while (distance->size() <= arc.nextstate) {
    160         distance->push_back(Weight::Zero());
    161         enqueued.push_back(false);
    162         parent.push_back(kNoStateId);
    163         arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(),
    164                                  kNoStateId));
    165       }
    166       Weight &nd = (*distance)[arc.nextstate];
    167       Weight w = Times(sd, arc.weight);
    168       if (nd != Plus(nd, w)) {
    169         nd = Plus(nd, w);
    170         if (!nd.Member()) {
    171           ofst->SetProperties(kError, kError);
    172           return;
    173         }
    174         parent[arc.nextstate] = s;
    175         arc_parent[arc.nextstate] = arc;
    176         if (!enqueued[arc.nextstate]) {
    177           state_queue->Enqueue(arc.nextstate);
    178           enqueued[arc.nextstate] = true;
    179         } else {
    180           state_queue->Update(arc.nextstate);
    181         }
    182       }
    183     }
    184   }
    185 
    186   StateId s_p = kNoStateId, d_p = kNoStateId;
    187   for (StateId s = f_parent, d = kNoStateId;
    188        s != kNoStateId;
    189        d = s, s = parent[s]) {
    190     d_p = s_p;
    191     s_p = ofst->AddState();
    192     if (d == kNoStateId) {
    193       ofst->SetFinal(s_p, ifst.Final(f_parent));
    194     } else {
    195       arc_parent[d].nextstate = d_p;
    196       ofst->AddArc(s_p, arc_parent[d]);
    197     }
    198   }
    199   ofst->SetStart(s_p);
    200   if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
    201   ofst->SetProperties(
    202       ShortestPathProperties(ofst->Properties(kFstProperties, false)),
    203       kFstProperties);
    204 }
    205 
    206 
    207 template <class S, class W>
    208 class ShortestPathCompare {
    209  public:
    210   typedef S StateId;
    211   typedef W Weight;
    212   typedef pair<StateId, Weight> Pair;
    213 
    214   ShortestPathCompare(const vector<Pair>& pairs,
    215                       const vector<Weight>& distance,
    216                       StateId sfinal, float d)
    217       : pairs_(pairs), distance_(distance), superfinal_(sfinal), delta_(d)  {}
    218 
    219   bool operator()(const StateId x, const StateId y) const {
    220     const Pair &px = pairs_[x];
    221     const Pair &py = pairs_[y];
    222     Weight dx = px.first == superfinal_ ? Weight::One() :
    223         px.first < distance_.size() ? distance_[px.first] : Weight::Zero();
    224     Weight dy = py.first == superfinal_ ? Weight::One() :
    225         py.first < distance_.size() ? distance_[py.first] : Weight::Zero();
    226     Weight wx = Times(dx, px.second);
    227     Weight wy = Times(dy, py.second);
    228     // Penalize complete paths to ensure correct results with inexact weights.
    229     // This forms a strict weak order so long as ApproxEqual(a, b) =>
    230     // ApproxEqual(a, c) for all c s.t. less_(a, c) && less_(c, b).
    231     if (px.first == superfinal_ && py.first != superfinal_) {
    232       return less_(wy, wx) || ApproxEqual(wx, wy, delta_);
    233     } else if (py.first == superfinal_ && px.first != superfinal_) {
    234       return less_(wy, wx) && !ApproxEqual(wx, wy, delta_);
    235     } else {
    236       return less_(wy, wx);
    237     }
    238   }
    239 
    240  private:
    241   const vector<Pair> &pairs_;
    242   const vector<Weight> &distance_;
    243   StateId superfinal_;
    244   float delta_;
    245   NaturalLess<Weight> less_;
    246 };
    247 
    248 
// N-Shortest-path algorithm: implements the core n-shortest path
// algorithm. The output is built REVERSED. See below for versions with
// more options and not reversed.
//
// 'ofst' contains the REVERSE of 'n'-shortest paths in 'ifst'.
// 'distance' must contain the shortest distance from each state to a final
// state in 'ifst'. 'delta' is the convergence delta. 'ifst' is over
// 'RevArc'; its arc and final weights are converted to 'Arc' weights with
// Reverse().
//
// Pruning:
// - 'weight_threshold': a partial path is dropped when its weight, extended
//   by 'distance', is worse (under NaturalLess) than
//   distance[ifst.Start()] x weight_threshold. Weight::Zero() disables this.
// - 'state_threshold': queued paths are no longer expanded once 'ofst' has
//   that many states. kNoStateId disables this.
//
// 'ofst' is left empty when any of these hold:
// - 'ifst' has no start state;
// - 'distance' has no entry for the start state, or that entry is Zero();
// - 'weight_threshold' is less than One() under NaturalLess;
// - 'state_threshold' is 0.
// With n == 0 the function returns immediately without touching 'ofst'.
//
// The n-shortest paths are the n-lowest weight paths w.r.t. the
// natural semiring order. The single path that can be read from the
// ith of at most n transitions leaving the initial state of 'ofst' is
// the ith shortest path. Disregarding the initial state and initial
// transitions, the n-shortest paths, in fact, form a tree rooted at
// the single final state.
//
// The weights need to be left and right distributive (kSemiring) and
// have the path (kPath) property.
//
// The algorithm is from Mohri and Riley, "An Efficient Algorithm for
// the n-best-strings problem", ICSLP 2002. The algorithm relies on
// the shortest-distance algorithm. There are some issues with the
// pseudo-code as written in the paper (viz., line 11).
//
// IMPLEMENTATION NOTE: The input fst 'ifst' can be a delayed fst, and
// at any state in its expansion the values of the distance vector need only
// be defined at that time for the states that are known to exist.
template<class Arc, class RevArc>
void NShortestPath(const Fst<RevArc> &ifst,
                   MutableFst<Arc> *ofst,
                   const vector<typename Arc::Weight> &distance,
                   size_t n,
                   float delta = kDelta,
                   typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
                   typename Arc::StateId state_threshold = kNoStateId) {
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Weight Weight;
  typedef pair<StateId, Weight> Pair;
  typedef typename RevArc::Weight RevWeight;  // NOTE: not used below.

  if (n <= 0) return;
  if ((Weight::Properties() & (kPath | kSemiring)) != (kPath | kSemiring)) {
    FSTERROR() << "NShortestPath: Weight needs to have the "
                 << "path property and be distributive: "
                 << Weight::Type();
    ofst->SetProperties(kError, kError);
    return;
  }
  ofst->DeleteStates();
  ofst->SetInputSymbols(ifst.InputSymbols());
  ofst->SetOutputSymbols(ifst.OutputSymbols());
  // Each state in 'ofst' corresponds to a path with weight w from the
  // initial state of 'ifst' to a state s in 'ifst'. That path can be
  // characterized by a pair (s,w). The vector 'pairs' maps each state in
  // 'ofst' (by state id) to its corresponding pair (s,w).
  vector<Pair> pairs;
  // The superfinal state is denoted by -1. 'compare' knows that the
  // distance from 'superfinal' to the final state is 'Weight::One()',
  // hence 'distance[superfinal]' is not needed.
  StateId superfinal = -1;
  ShortestPathCompare<StateId, Weight>
    compare(pairs, distance, superfinal, delta);
  // Heap of 'ofst' state ids, ordered by 'compare' (best entry on top).
  vector<StateId> heap;
  // 'r[s + 1]', for 's' a state in 'ifst', is the number of states in
  // 'ofst' whose corresponding pair contains 's', i.e., the number of
  // paths computed so far to 's'. Valid for 's == -1' (superfinal).
  vector<int> r;
  NaturalLess<Weight> less;
  if (ifst.Start() == kNoStateId ||
      distance.size() <= ifst.Start() ||
      distance[ifst.Start()] == Weight::Zero() ||
      less(weight_threshold, Weight::One()) ||
      state_threshold == 0) {
    if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
    return;
  }
  // Output state 'Start()' collects one arc per completed path. 'final' is
  // the root of the path tree and stands for the empty path at the start of
  // 'ifst'. Entries of 'pairs' below 'final' are placeholders. (Assumes
  // AddState() hands out consecutive ids, so that 'pairs' stays aligned with
  // the state ids of 'ofst'.)
  ofst->SetStart(ofst->AddState());
  StateId final = ofst->AddState();
  ofst->SetFinal(final, Weight::One());
  while (pairs.size() <= final)
    pairs.push_back(Pair(kNoStateId, Weight::Zero()));
  pairs[final] = Pair(ifst.Start(), Weight::One());
  heap.push_back(final);
  // Pruning bound used by the weight threshold.
  Weight limit = Times(distance[ifst.Start()], weight_threshold);

  while (!heap.empty()) {
    // Take the best pending path.
    pop_heap(heap.begin(), heap.end(), compare);
    StateId state = heap.back();
    Pair p = pairs[state];
    heap.pop_back();
    Weight d = p.first == superfinal ? Weight::One() :
        p.first < distance.size() ? distance[p.first] : Weight::Zero();

    // Drop paths beyond the weight bound, or any path once the state budget
    // is exhausted.
    if (less(limit, Times(d, p.second)) ||
        (state_threshold != kNoStateId &&
         ofst->NumStates() >= state_threshold))
      continue;

    while (r.size() <= p.first + 1) r.push_back(0);
    ++r[p.first + 1];
    // A path reaching the superfinal state is complete: link it from the
    // start state, and stop after the n-th complete path.
    if (p.first == superfinal)
      ofst->AddArc(ofst->Start(), Arc(0, 0, Weight::One(), state));
    if ((p.first == superfinal) && (r[p.first + 1] == n)) break;
    // Only the first n paths to any state are expanded further.
    if (r[p.first + 1] > n) continue;
    if (p.first == superfinal) continue;

    // Extend the path along each arc of 'ifst'. The new 'ofst' state points
    // back to 'state', which is why the output is built reversed.
    for (ArcIterator< Fst<RevArc> > aiter(ifst, p.first);
         !aiter.Done();
         aiter.Next()) {
      const RevArc &rarc = aiter.Value();
      Arc arc(rarc.ilabel, rarc.olabel, rarc.weight.Reverse(), rarc.nextstate);
      Weight w = Times(p.second, arc.weight);
      StateId next = ofst->AddState();
      pairs.push_back(Pair(arc.nextstate, w));
      arc.nextstate = state;
      ofst->AddArc(next, arc);
      heap.push_back(next);
      push_heap(heap.begin(), heap.end(), compare);
    }

    // A final state of 'ifst' also extends the path to the superfinal state,
    // through an epsilon arc carrying the final weight.
    Weight finalw = ifst.Final(p.first).Reverse();
    if (finalw != Weight::Zero()) {
      Weight w = Times(p.second, finalw);
      StateId next = ofst->AddState();
      pairs.push_back(Pair(superfinal, w));
      ofst->AddArc(next, Arc(0, 0, finalw, state));
      heap.push_back(next);
      push_heap(heap.begin(), heap.end(), compare);
    }
  }
  // Drop partial paths that were never linked from the start state.
  Connect(ofst);
  if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
  ofst->SetProperties(
      ShortestPathProperties(ofst->Properties(kFstProperties, false)),
      kFstProperties);
}
    385 
    386 
    387 // N-Shortest-path algorithm:  this version allow fine control
    388 // via the options argument. See below for a simpler interface.
    389 //
    390 // 'ofst' contains the n-shortest paths in 'ifst'. 'distance' returns
    391 // the shortest distances from the source state to each state in
    392 // 'ifst'. 'opts' is used to specify options such as the number of
    393 // paths to return, whether they need to have distinct input
    394 // strings, the queue discipline, the arc filter and the convergence
    395 // delta.
    396 //
    397 // The n-shortest paths are the n-lowest weight paths w.r.t. the
    398 // natural semiring order. The single path that can be read from the
    399 // ith of at most n transitions leaving the initial state of 'ofst' is
    400 // the ith shortest path. Disregarding the initial state and initial
    401 // transitions, The n-shortest paths, in fact, form a tree rooted at
    402 // the single final state.
    403 
    404 // The weights need to be right distributive and have the path (kPath)
    405 // property. They need to be left distributive as well for nshortest
    406 // > 1.
    407 //
    408 // The algorithm is from Mohri and Riley, "An Efficient Algorithm for
    409 // the n-best-strings problem", ICSLP 2002. The algorithm relies on
    410 // the shortest-distance algorithm. There are some issues with the
    411 // pseudo-code as written in the paper (viz., line 11).
    412 template<class Arc, class Queue, class ArcFilter>
    413 void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
    414                   vector<typename Arc::Weight> *distance,
    415                   ShortestPathOptions<Arc, Queue, ArcFilter> &opts) {
    416   typedef typename Arc::StateId StateId;
    417   typedef typename Arc::Weight Weight;
    418   typedef ReverseArc<Arc> ReverseArc;
    419 
    420   size_t n = opts.nshortest;
    421   if (n == 1) {
    422     SingleShortestPath(ifst, ofst, distance, opts);
    423     return;
    424   }
    425   if (n <= 0) return;
    426   if ((Weight::Properties() & (kPath | kSemiring)) != (kPath | kSemiring)) {
    427     FSTERROR() << "ShortestPath: n-shortest: Weight needs to have the "
    428                << "path property and be distributive: "
    429                << Weight::Type();
    430     ofst->SetProperties(kError, kError);
    431     return;
    432   }
    433   if (!opts.has_distance) {
    434     ShortestDistance(ifst, distance, opts);
    435     if (distance->size() == 1 && !(*distance)[0].Member()) {
    436       ofst->SetProperties(kError, kError);
    437       return;
    438     }
    439   }
    440   // Algorithm works on the reverse of 'fst' : 'rfst', 'distance' is
    441   // the distance to the final state in 'rfst', 'ofst' is built as the
    442   // reverse of the tree of n-shortest path in 'rfst'.
    443   VectorFst<ReverseArc> rfst;
    444   Reverse(ifst, &rfst);
    445   Weight d = Weight::Zero();
    446   for (ArcIterator< VectorFst<ReverseArc> > aiter(rfst, 0);
    447        !aiter.Done(); aiter.Next()) {
    448     const ReverseArc &arc = aiter.Value();
    449     StateId s = arc.nextstate - 1;
    450     if (s < distance->size())
    451       d = Plus(d, Times(arc.weight.Reverse(), (*distance)[s]));
    452   }
    453   distance->insert(distance->begin(), d);
    454 
    455   if (!opts.unique) {
    456     NShortestPath(rfst, ofst, *distance, n, opts.delta,
    457                   opts.weight_threshold, opts.state_threshold);
    458   } else {
    459     vector<Weight> ddistance;
    460     DeterminizeFstOptions<ReverseArc> dopts(opts.delta);
    461     DeterminizeFst<ReverseArc> dfst(rfst, distance, &ddistance, dopts);
    462     NShortestPath(dfst, ofst, ddistance, n, opts.delta,
    463                   opts.weight_threshold, opts.state_threshold);
    464   }
    465   distance->erase(distance->begin());
    466 }
    467 
    468 
    469 // Shortest-path algorithm: simplified interface. See above for a
    470 // version that allows finer control.
    471 //
    472 // 'ofst' contains the 'n'-shortest paths in 'ifst'. The queue
    473 // discipline is automatically selected. When 'unique' == true, only
    474 // paths with distinct input labels are returned.
    475 //
    476 // The n-shortest paths are the n-lowest weight paths w.r.t. the
    477 // natural semiring order. The single path that can be read from the
    478 // ith of at most n transitions leaving the initial state of 'ofst' is
    479 // the ith best path.
    480 //
    481 // The weights need to be right distributive and have the path
    482 // (kPath) property.
    483 template<class Arc>
    484 void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
    485                   size_t n = 1, bool unique = false,
    486                   bool first_path = false,
    487                   typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
    488                   typename Arc::StateId state_threshold = kNoStateId) {
    489   vector<typename Arc::Weight> distance;
    490   AnyArcFilter<Arc> arc_filter;
    491   AutoQueue<typename Arc::StateId> state_queue(ifst, &distance, arc_filter);
    492   ShortestPathOptions< Arc, AutoQueue<typename Arc::StateId>,
    493       AnyArcFilter<Arc> > opts(&state_queue, arc_filter, n, unique, false,
    494                                kDelta, first_path, weight_threshold,
    495                                state_threshold);
    496   ShortestPath(ifst, ofst, &distance, opts);
    497 }
    498 
    499 }  // namespace fst
    500 
    501 #endif  // FST_LIB_SHORTEST_PATH_H__
    502