Home | History | Annotate | Download | only in lib
      1 // shortest-path.h
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Author: allauzen (at) cs.nyu.edu (Cyril Allauzen)
     16 //
     17 // \file
     18 // Functions to find shortest paths in an FST.
     19 
     20 #ifndef FST_LIB_SHORTEST_PATH_H__
     21 #define FST_LIB_SHORTEST_PATH_H__
     22 
     23 #include <functional>
     24 
     25 #include "fst/lib/cache.h"
     26 #include "fst/lib/queue.h"
     27 #include "fst/lib/shortest-distance.h"
     28 #include "fst/lib/test-properties.h"
     29 
     30 namespace fst {
     31 
     32 template <class Arc, class Queue, class ArcFilter>
     33 struct ShortestPathOptions
     34     : public ShortestDistanceOptions<Arc, Queue, ArcFilter> {
     35   typedef typename Arc::StateId StateId;
     36 
     37   size_t nshortest;      // return n-shortest paths
     38   bool unique;           // only return paths with distinct input strings
     39   bool has_distance;     // distance vector already contains the
     40                          // shortest distance from the initial state
     41 
     42   ShortestPathOptions(Queue *q, ArcFilter filt, size_t n = 1, bool u = false,
     43                       bool hasdist = false, float d = kDelta)
     44       : ShortestDistanceOptions<Arc, Queue, ArcFilter>(q, filt, kNoStateId, d),
     45         nshortest(n), unique(u), has_distance(hasdist)  {}
     46 };
     47 
     48 
// Single-shortest-path algorithm: normally not called directly; prefer
// 'ShortestPath' below with n=1. 'ofst' contains the shortest path in
// 'ifst'. 'distance' returns the shortest distances from the source
// state to each state in 'ifst'. 'opts' supplies the queue discipline
// and the source state, and its 'nshortest' field must be 1. Note that
// this function does not consult the arc filter or delta in 'opts'.
//
// The shortest path is the lowest weight path w.r.t. the natural
// semiring order.
//
// The weights need to be right distributive and have the path (kPath)
// property.
     60 template<class Arc, class Queue, class ArcFilter>
     61 void SingleShortestPath(const Fst<Arc> &ifst,
     62                   MutableFst<Arc> *ofst,
     63                   vector<typename Arc::Weight> *distance,
     64                   ShortestPathOptions<Arc, Queue, ArcFilter> &opts) {
     65   typedef typename Arc::StateId StateId;
     66   typedef typename Arc::Weight Weight;
     67 
     68   ofst->DeleteStates();
     69   ofst->SetInputSymbols(ifst.InputSymbols());
     70   ofst->SetOutputSymbols(ifst.OutputSymbols());
     71 
     72   if (ifst.Start() == kNoStateId)
     73     return;
     74 
     75   vector<Weight> rdistance;
     76   vector<bool> enqueued;
     77   vector<StateId> parent;
     78   vector<Arc> arc_parent;
     79 
     80   Queue *state_queue = opts.state_queue;
     81   StateId source = opts.source == kNoStateId ? ifst.Start() : opts.source;
     82   Weight f_distance = Weight::Zero();
     83   StateId f_parent = kNoStateId;
     84 
     85   distance->clear();
     86   state_queue->Clear();
     87   if (opts.nshortest != 1)
     88     LOG(FATAL) << "SingleShortestPath: for nshortest > 1, use ShortestPath"
     89                << " instead";
     90   if ((Weight::Properties() & (kPath | kRightSemiring))
     91        != (kPath | kRightSemiring))
     92       LOG(FATAL) << "SingleShortestPath: Weight needs to have the path"
     93                  << " property and be right distributive: " << Weight::Type();
     94 
     95   while (distance->size() < source) {
     96     distance->push_back(Weight::Zero());
     97     enqueued.push_back(false);
     98     parent.push_back(kNoStateId);
     99     arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId));
    100   }
    101   distance->push_back(Weight::One());
    102   parent.push_back(kNoStateId);
    103   arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId));
    104   state_queue->Enqueue(source);
    105   enqueued.push_back(true);
    106 
    107   while (!state_queue->Empty()) {
    108     StateId s = state_queue->Head();
    109     state_queue->Dequeue();
    110     enqueued[s] = false;
    111     Weight sd = (*distance)[s];
    112     for (ArcIterator< Fst<Arc> > aiter(ifst, s);
    113          !aiter.Done();
    114          aiter.Next()) {
    115       const Arc &arc = aiter.Value();
    116       while (distance->size() <= arc.nextstate) {
    117         distance->push_back(Weight::Zero());
    118         enqueued.push_back(false);
    119         parent.push_back(kNoStateId);
    120         arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(),
    121                                  kNoStateId));
    122       }
    123       Weight &nd = (*distance)[arc.nextstate];
    124       Weight w = Times(sd, arc.weight);
    125       if (nd != Plus(nd, w)) {
    126         nd = Plus(nd, w);
    127         parent[arc.nextstate] = s;
    128         arc_parent[arc.nextstate] = arc;
    129         if (!enqueued[arc.nextstate]) {
    130           state_queue->Enqueue(arc.nextstate);
    131           enqueued[arc.nextstate] = true;
    132         } else {
    133           state_queue->Update(arc.nextstate);
    134         }
    135       }
    136     }
    137     if (ifst.Final(s) != Weight::Zero()) {
    138       Weight w = Times(sd, ifst.Final(s));
    139       if (f_distance != Plus(f_distance, w)) {
    140         f_distance = Plus(f_distance, w);
    141         f_parent = s;
    142       }
    143     }
    144   }
    145   (*distance)[source] = Weight::One();
    146   parent[source] = kNoStateId;
    147 
    148   StateId s_p = kNoStateId, d_p = kNoStateId;
    149   for (StateId s = f_parent, d = kNoStateId;
    150        s != kNoStateId;
    151        d = s, s = parent[s]) {
    152     enqueued[s] = true;
    153     d_p = s_p;
    154     s_p = ofst->AddState();
    155     if (d == kNoStateId) {
    156       ofst->SetFinal(s_p, ifst.Final(f_parent));
    157     } else {
    158       arc_parent[d].nextstate = d_p;
    159       ofst->AddArc(s_p, arc_parent[d]);
    160     }
    161   }
    162   ofst->SetStart(s_p);
    163 }
    164 
    165 
    166 template <class S, class W>
    167 class ShortestPathCompare {
    168  public:
    169   typedef S StateId;
    170   typedef W Weight;
    171   typedef pair<StateId, Weight> Pair;
    172 
    173   ShortestPathCompare(const vector<Pair>& pairs,
    174                       const vector<Weight>& distance,
    175                       StateId sfinal, float d)
    176       : pairs_(pairs), distance_(distance), superfinal_(sfinal), delta_(d)  {}
    177 
    178   bool operator()(const StateId x, const StateId y) const {
    179     const Pair &px = pairs_[x];
    180     const Pair &py = pairs_[y];
    181     Weight wx = Times(distance_[px.first], px.second);
    182     Weight wy = Times(distance_[py.first], py.second);
    183     // Penalize complete paths to ensure correct results with inexact weights.
    184     // This forms a strict weak order so long as ApproxEqual(a, b) =>
    185     // ApproxEqual(a, c) for all c s.t. less_(a, c) && less_(c, b).
    186     if (px.first == superfinal_ && py.first != superfinal_) {
    187       return less_(wy, wx) || ApproxEqual(wx, wy, delta_);
    188     } else if (py.first == superfinal_ && px.first != superfinal_) {
    189       return less_(wy, wx) && !ApproxEqual(wx, wy, delta_);
    190     } else {
    191       return less_(wy, wx);
    192     }
    193   }
    194 
    195  private:
    196   const vector<Pair> &pairs_;
    197   const vector<Weight> &distance_;
    198   StateId superfinal_;
    199   float delta_;
    200   NaturalLess<Weight> less_;
    201 };
    202 
    203 
// N-Shortest-path algorithm: this version allows fine control
// via the options argument. See below for a simpler interface.
//
// 'ofst' contains the n-shortest paths in 'ifst'. 'distance' returns
// the shortest distances from the source state to each state in
// 'ifst'. 'opts' is used to specify options such as the number of
// paths to return, whether they need to have distinct input
// strings, the queue discipline, the arc filter and the convergence
// delta. Distinct input strings (unique == true) are not yet supported
// for nshortest > 1; requesting them aborts with LOG(FATAL).
//
// The n-shortest paths are the n-lowest weight paths w.r.t. the
// natural semiring order. The single path that can be read from the
// ith of at most n transitions leaving the initial state of 'ofst' is
// the ith shortest path.
//
// The weights need to be right distributive and have the path (kPath)
// property. They need to be left distributive as well for nshortest
// > 1.
//
// The algorithm is from Mohri and Riley, "An Efficient Algorithm for
// the n-best-strings problem", ICSLP 2002. The algorithm relies on
// the shortest-distance algorithm. There are some issues with the
// pseudo-code as written in the paper (viz., line 11).
    227 template<class Arc, class Queue, class ArcFilter>
    228 void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
    229                   vector<typename Arc::Weight> *distance,
    230                   ShortestPathOptions<Arc, Queue, ArcFilter> &opts) {
    231   typedef typename Arc::StateId StateId;
    232   typedef typename Arc::Weight Weight;
    233   typedef pair<StateId, Weight> Pair;
    234   typedef ReverseArc<Arc> ReverseArc;
    235   typedef typename ReverseArc::Weight ReverseWeight;
    236 
    237   size_t n = opts.nshortest;
    238 
    239   if (n == 1) {
    240     SingleShortestPath(ifst, ofst, distance, opts);
    241     return;
    242   }
    243   ofst->DeleteStates();
    244   ofst->SetInputSymbols(ifst.InputSymbols());
    245   ofst->SetOutputSymbols(ifst.OutputSymbols());
    246   if (n <= 0) return;
    247   if ((Weight::Properties() & (kPath | kSemiring)) != (kPath | kSemiring))
    248     LOG(FATAL) << "ShortestPath: n-shortest: Weight needs to have the "
    249                  << "path property and be distributive: "
    250                  << Weight::Type();
    251   if (opts.unique)
    252     LOG(FATAL) << "ShortestPath: n-shortest-string algorithm not "
    253                << "currently implemented";
    254 
    255   // Algorithm works on the reverse of 'fst' : 'rfst' 'distance' is
    256   // the distance to the final state in 'rfst' 'ofst' is built as the
    257   // reverse of the tree of n-shortest path in 'rfst'.
    258 
    259   if (!opts.has_distance)
    260     ShortestDistance(ifst, distance, opts);
    261   VectorFst<ReverseArc> rfst;
    262   Reverse(ifst, &rfst);
    263   distance->insert(distance->begin(), Weight::One());
    264   while (distance->size() < rfst.NumStates())
    265     distance->push_back(Weight::Zero());
    266 
    267 
    268   // Each state in 'ofst' corresponds to a path with weight w from the
    269   // initial state of 'rfst' to a state s in 'rfst', that can be
    270   // characterized by a pair (s,w).  The vector 'pairs' maps each
    271   // state in 'ofst' to the corresponding pair maps states in OFST to
    272   // the corresponding pair (s,w).
    273   vector<Pair> pairs;
    274   // 'r[s]', 's' state in 'fst', is the number of states in 'ofst'
    275   // which corresponding pair contains 's' ,i.e. , it is number of
    276   // paths computed so far to 's'.
    277   StateId superfinal = distance->size();  // superfinal must be handled
    278   distance->push_back(Weight::One());     // differently when unique=true
    279   ShortestPathCompare<StateId, Weight>
    280     compare(pairs, *distance, superfinal, opts.delta);
    281   vector<StateId> heap;
    282   vector<int> r;
    283   while (r.size() < distance->size())
    284     r.push_back(0);
    285   ofst->SetStart(ofst->AddState());
    286   StateId final = ofst->AddState();
    287   ofst->SetFinal(final, Weight::One());
    288   while (pairs.size() <= final)
    289     pairs.push_back(Pair(kNoStateId, Weight::Zero()));
    290   pairs[final] = Pair(rfst.Start(), Weight::One());
    291   heap.push_back(final);
    292 
    293   while (!heap.empty()) {
    294     pop_heap(heap.begin(), heap.end(), compare);
    295     StateId state = heap.back();
    296     Pair p = pairs[state];
    297     heap.pop_back();
    298 
    299     ++r[p.first];
    300     if (p.first == superfinal)
    301       ofst->AddArc(ofst->Start(), Arc(0, 0, Weight::One(), state));
    302     if ((p.first == superfinal) &&  (r[p.first] == n)) break;
    303     if (r[p.first] > n) continue;
    304     if (p.first == superfinal)
    305       continue;
    306 
    307     for (ArcIterator< Fst<ReverseArc> > aiter(rfst, p.first);
    308          !aiter.Done();
    309          aiter.Next()) {
    310       const ReverseArc &rarc = aiter.Value();
    311       Arc arc(rarc.ilabel, rarc.olabel, rarc.weight.Reverse(), rarc.nextstate);
    312       Weight w = Times(p.second, arc.weight);
    313       StateId next = ofst->AddState();
    314       pairs.push_back(Pair(arc.nextstate, w));
    315       arc.nextstate = state;
    316       ofst->AddArc(next, arc);
    317       heap.push_back(next);
    318       push_heap(heap.begin(), heap.end(), compare);
    319     }
    320 
    321     Weight finalw = rfst.Final(p.first).Reverse();
    322     if (finalw != Weight::Zero()) {
    323       Weight w = Times(p.second, finalw);
    324       StateId next = ofst->AddState();
    325       pairs.push_back(Pair(superfinal, w));
    326       ofst->AddArc(next, Arc(0, 0, finalw, state));
    327       heap.push_back(next);
    328       push_heap(heap.begin(), heap.end(), compare);
    329     }
    330   }
    331   Connect(ofst);
    332   distance->erase(distance->begin());
    333   distance->pop_back();
    334 }
    335 
// Shortest-path algorithm: simplified interface. See above for a
// version that allows finer control.
//
// 'ofst' contains the 'n'-shortest paths in 'ifst'. The queue
// discipline is automatically selected. 'unique' == true requests
// only paths with distinct input strings; this is not yet implemented
// for n > 1 (the call aborts with LOG(FATAL)) and has no effect for
// n == 1.
//
// The n-shortest paths are the n-lowest weight paths w.r.t. the
// natural semiring order. The single path that can be read from the
// ith of at most n transitions leaving the initial state of 'ofst' is
// the ith best path.
//
// The weights need to be right distributive and have the path
// (kPath) property; for n > 1 they must also be left distributive.
    350 template<class Arc>
    351 void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
    352                   size_t n = 1, bool unique = false) {
    353   vector<typename Arc::Weight> distance;
    354   AnyArcFilter<Arc> arc_filter;
    355   AutoQueue<typename Arc::StateId> state_queue(ifst, &distance, arc_filter);
    356   ShortestPathOptions< Arc, AutoQueue<typename Arc::StateId>,
    357     AnyArcFilter<Arc> > opts(&state_queue, arc_filter, n, unique);
    358   ShortestPath(ifst, ofst, &distance, opts);
    359 }
    360 
    361 }  // namespace fst
    362 
    363 #endif  // FST_LIB_SHORTEST_PATH_H__
    364