Home | History | Annotate | Download | only in fst
      1 // shortest-distance.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: allauzen (at) google.com (Cyril Allauzen)
     17 //
     18 // \file
     19 // Functions and classes to find shortest distance in an FST.
     20 
     21 #ifndef FST_LIB_SHORTEST_DISTANCE_H__
     22 #define FST_LIB_SHORTEST_DISTANCE_H__
     23 
     24 #include <deque>
     25 using std::deque;
     26 #include <vector>
     27 using std::vector;
     28 
     29 #include <fst/arcfilter.h>
     30 #include <fst/cache.h>
     31 #include <fst/queue.h>
     32 #include <fst/reverse.h>
     33 #include <fst/test-properties.h>
     34 
     35 
     36 namespace fst {
     37 
     38 template <class Arc, class Queue, class ArcFilter>
     39 struct ShortestDistanceOptions {
     40   typedef typename Arc::StateId StateId;
     41 
     42   Queue *state_queue;    // Queue discipline used; owned by caller
     43   ArcFilter arc_filter;  // Arc filter (e.g., limit to only epsilon graph)
     44   StateId source;        // If kNoStateId, use the Fst's initial state
     45   float delta;           // Determines the degree of convergence required
     46   bool first_path;       // For a semiring with the path property (o.w.
     47                          // undefined), compute the shortest-distances along
     48                          // along the first path to a final state found
     49                          // by the algorithm. That path is the shortest-path
     50                          // only if the FST has a unique final state (or all
     51                          // the final states have the same final weight), the
     52                          // queue discipline is shortest-first and all the
     53                          // weights in the FST are between One() and Zero()
     54                          // according to NaturalLess.
     55 
     56   ShortestDistanceOptions(Queue *q, ArcFilter filt, StateId src = kNoStateId,
     57                           float d = kDelta)
     58       : state_queue(q), arc_filter(filt), source(src), delta(d),
     59         first_path(false) {}
     60 };
     61 
     62 
     63 // Computation state of the shortest-distance algorithm. Reusable
     64 // information is maintained across calls to member function
     65 // ShortestDistance(source) when 'retain' is true for improved
     66 // efficiency when calling multiple times from different source states
     67 // (e.g., in epsilon removal). Contrary to usual conventions, 'fst'
     68 // may not be freed before this class. Vector 'distance' should not be
     69 // modified by the user between these calls.
     70 // The Error() method returns true if an error was encountered.
     71 template<class Arc, class Queue, class ArcFilter>
     72 class ShortestDistanceState {
     73  public:
     74   typedef typename Arc::StateId StateId;
     75   typedef typename Arc::Weight Weight;
     76 
     77   ShortestDistanceState(
     78       const Fst<Arc> &fst,
     79       vector<Weight> *distance,
     80       const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts,
     81       bool retain)
     82       : fst_(fst), distance_(distance), state_queue_(opts.state_queue),
     83         arc_filter_(opts.arc_filter), delta_(opts.delta),
     84         first_path_(opts.first_path), retain_(retain), source_id_(0),
     85         error_(false) {
     86     distance_->clear();
     87   }
     88 
     89   ~ShortestDistanceState() {}
     90 
     91   void ShortestDistance(StateId source);
     92 
     93   bool Error() const { return error_; }
     94 
     95  private:
     96   const Fst<Arc> &fst_;
     97   vector<Weight> *distance_;
     98   Queue *state_queue_;
     99   ArcFilter arc_filter_;
    100   float delta_;
    101   bool first_path_;
    102   bool retain_;               // Retain and reuse information across calls
    103 
    104   vector<Weight> rdistance_;  // Relaxation distance.
    105   vector<bool> enqueued_;     // Is state enqueued?
    106   vector<StateId> sources_;   // Source ID for ith state in 'distance_',
    107                               //  'rdistance_', and 'enqueued_' if retained.
    108   StateId source_id_;         // Unique ID characterizing each call to SD
    109 
    110   bool error_;
    111 };
    112 
    113 // Compute the shortest distance. If 'source' is kNoStateId, use
    114 // the initial state of the Fst.
    115 template <class Arc, class Queue, class ArcFilter>
    116 void ShortestDistanceState<Arc, Queue, ArcFilter>::ShortestDistance(
    117     StateId source) {
    118   if (fst_.Start() == kNoStateId) {
    119     if (fst_.Properties(kError, false)) error_ = true;
    120     return;
    121   }
    122 
    123   if (!(Weight::Properties() & kRightSemiring)) {
    124     FSTERROR() << "ShortestDistance: Weight needs to be right distributive: "
    125                << Weight::Type();
    126     error_ = true;
    127     return;
    128   }
    129 
    130   if (first_path_ && !(Weight::Properties() & kPath)) {
    131     FSTERROR() << "ShortestDistance: first_path option disallowed when "
    132                << "Weight does not have the path property: "
    133                << Weight::Type();
    134     error_ = true;
    135     return;
    136   }
    137 
    138   state_queue_->Clear();
    139 
    140   if (!retain_) {
    141     distance_->clear();
    142     rdistance_.clear();
    143     enqueued_.clear();
    144   }
    145 
    146   if (source == kNoStateId)
    147     source = fst_.Start();
    148 
    149   while (distance_->size() <= source) {
    150     distance_->push_back(Weight::Zero());
    151     rdistance_.push_back(Weight::Zero());
    152     enqueued_.push_back(false);
    153   }
    154   if (retain_) {
    155     while (sources_.size() <= source)
    156       sources_.push_back(kNoStateId);
    157     sources_[source] = source_id_;
    158   }
    159   (*distance_)[source] = Weight::One();
    160   rdistance_[source] = Weight::One();
    161   enqueued_[source] = true;
    162 
    163   state_queue_->Enqueue(source);
    164 
    165   while (!state_queue_->Empty()) {
    166     StateId s = state_queue_->Head();
    167     state_queue_->Dequeue();
    168     while (distance_->size() <= s) {
    169       distance_->push_back(Weight::Zero());
    170       rdistance_.push_back(Weight::Zero());
    171       enqueued_.push_back(false);
    172     }
    173     if (first_path_ && (fst_.Final(s) != Weight::Zero()))
    174       break;
    175     enqueued_[s] = false;
    176     Weight r = rdistance_[s];
    177     rdistance_[s] = Weight::Zero();
    178     for (ArcIterator< Fst<Arc> > aiter(fst_, s);
    179          !aiter.Done();
    180          aiter.Next()) {
    181       const Arc &arc = aiter.Value();
    182       if (!arc_filter_(arc))
    183         continue;
    184       while (distance_->size() <= arc.nextstate) {
    185         distance_->push_back(Weight::Zero());
    186         rdistance_.push_back(Weight::Zero());
    187         enqueued_.push_back(false);
    188       }
    189       if (retain_) {
    190         while (sources_.size() <= arc.nextstate)
    191           sources_.push_back(kNoStateId);
    192         if (sources_[arc.nextstate] != source_id_) {
    193           (*distance_)[arc.nextstate] = Weight::Zero();
    194           rdistance_[arc.nextstate] = Weight::Zero();
    195           enqueued_[arc.nextstate] = false;
    196           sources_[arc.nextstate] = source_id_;
    197         }
    198       }
    199       Weight &nd = (*distance_)[arc.nextstate];
    200       Weight &nr = rdistance_[arc.nextstate];
    201       Weight w = Times(r, arc.weight);
    202       if (!ApproxEqual(nd, Plus(nd, w), delta_)) {
    203         nd = Plus(nd, w);
    204         nr = Plus(nr, w);
    205         if (!nd.Member() || !nr.Member()) {
    206           error_ = true;
    207           return;
    208         }
    209         if (!enqueued_[arc.nextstate]) {
    210           state_queue_->Enqueue(arc.nextstate);
    211           enqueued_[arc.nextstate] = true;
    212         } else {
    213           state_queue_->Update(arc.nextstate);
    214         }
    215       }
    216     }
    217   }
    218   ++source_id_;
    219   if (fst_.Properties(kError, false)) error_ = true;
    220 }
    221 
    222 
    223 // Shortest-distance algorithm: this version allows fine control
    224 // via the options argument. See below for a simpler interface.
    225 //
    226 // This computes the shortest distance from the 'opts.source' state to
    227 // each visited state S and stores the value in the 'distance' vector.
    228 // An unvisited state S has distance Zero(), which will be stored in
    229 // the 'distance' vector if S is less than the maximum visited state.
    230 // The state queue discipline, arc filter, and convergence delta are
    231 // taken in the options argument.
    232 // The 'distance' vector will contain a unique element for which
    233 // Member() is false if an error was encountered.
    234 //
    235 // The weights must must be right distributive and k-closed (i.e., 1 +
    236 // x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k).
    237 //
    238 // The algorithm is from Mohri, "Semiring Framweork and Algorithms for
    239 // Shortest-Distance Problems", Journal of Automata, Languages and
    240 // Combinatorics 7(3):321-350, 2002. The complexity of algorithm
    241 // depends on the properties of the semiring and the queue discipline
    242 // used. Refer to the paper for more details.
    243 template<class Arc, class Queue, class ArcFilter>
    244 void ShortestDistance(
    245     const Fst<Arc> &fst,
    246     vector<typename Arc::Weight> *distance,
    247     const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts) {
    248 
    249   ShortestDistanceState<Arc, Queue, ArcFilter>
    250     sd_state(fst, distance, opts, false);
    251   sd_state.ShortestDistance(opts.source);
    252   if (sd_state.Error()) {
    253     distance->clear();
    254     distance->resize(1, Arc::Weight::NoWeight());
    255   }
    256 }
    257 
    258 // Shortest-distance algorithm: simplified interface. See above for a
    259 // version that allows finer control.
    260 //
    261 // If 'reverse' is false, this computes the shortest distance from the
    262 // initial state to each state S and stores the value in the
    263 // 'distance' vector. If 'reverse' is true, this computes the shortest
    264 // distance from each state to the final states.  An unvisited state S
    265 // has distance Zero(), which will be stored in the 'distance' vector
    266 // if S is less than the maximum visited state.  The state queue
    267 // discipline is automatically-selected.
    268 // The 'distance' vector will contain a unique element for which
    269 // Member() is false if an error was encountered.
    270 //
    271 // The weights must must be right (left) distributive if reverse is
    272 // false (true) and k-closed (i.e., 1 + x + x^2 + ... + x^(k +1) = 1 +
    273 // x + x^2 + ... + x^k).
    274 //
    275 // The algorithm is from Mohri, "Semiring Framweork and Algorithms for
    276 // Shortest-Distance Problems", Journal of Automata, Languages and
    277 // Combinatorics 7(3):321-350, 2002. The complexity of algorithm
    278 // depends on the properties of the semiring and the queue discipline
    279 // used. Refer to the paper for more details.
    280 template<class Arc>
    281 void ShortestDistance(const Fst<Arc> &fst,
    282                       vector<typename Arc::Weight> *distance,
    283                       bool reverse = false,
    284                       float delta = kDelta) {
    285   typedef typename Arc::StateId StateId;
    286   typedef typename Arc::Weight Weight;
    287 
    288   if (!reverse) {
    289     AnyArcFilter<Arc> arc_filter;
    290     AutoQueue<StateId> state_queue(fst, distance, arc_filter);
    291     ShortestDistanceOptions< Arc, AutoQueue<StateId>, AnyArcFilter<Arc> >
    292       opts(&state_queue, arc_filter);
    293     opts.delta = delta;
    294     ShortestDistance(fst, distance, opts);
    295   } else {
    296     typedef ReverseArc<Arc> ReverseArc;
    297     typedef typename ReverseArc::Weight ReverseWeight;
    298     AnyArcFilter<ReverseArc> rarc_filter;
    299     VectorFst<ReverseArc> rfst;
    300     Reverse(fst, &rfst);
    301     vector<ReverseWeight> rdistance;
    302     AutoQueue<StateId> state_queue(rfst, &rdistance, rarc_filter);
    303     ShortestDistanceOptions< ReverseArc, AutoQueue<StateId>,
    304       AnyArcFilter<ReverseArc> >
    305       ropts(&state_queue, rarc_filter);
    306     ropts.delta = delta;
    307     ShortestDistance(rfst, &rdistance, ropts);
    308     distance->clear();
    309     if (rdistance.size() == 1 && !rdistance[0].Member()) {
    310       distance->resize(1, Arc::Weight::NoWeight());
    311       return;
    312     }
    313     while (distance->size() < rdistance.size() - 1)
    314       distance->push_back(rdistance[distance->size() + 1].Reverse());
    315   }
    316 }
    317 
    318 
    319 // Return the sum of the weight of all successful paths in an FST, i.e.,
    320 // the shortest-distance from the initial state to the final states.
    321 // Returns a weight such that Member() is false if an error was encountered.
    322 template <class Arc>
    323 typename Arc::Weight ShortestDistance(const Fst<Arc> &fst, float delta = kDelta) {
    324   typedef typename Arc::Weight Weight;
    325   typedef typename Arc::StateId StateId;
    326   vector<Weight> distance;
    327   if (Weight::Properties() & kRightSemiring) {
    328     ShortestDistance(fst, &distance, false, delta);
    329     if (distance.size() == 1 && !distance[0].Member())
    330       return Arc::Weight::NoWeight();
    331     Weight sum = Weight::Zero();
    332     for (StateId s = 0; s < distance.size(); ++s)
    333       sum = Plus(sum, Times(distance[s], fst.Final(s)));
    334     return sum;
    335   } else {
    336     ShortestDistance(fst, &distance, true, delta);
    337     StateId s = fst.Start();
    338     if (distance.size() == 1 && !distance[0].Member())
    339       return Arc::Weight::NoWeight();
    340     return s != kNoStateId && s < distance.size() ?
    341         distance[s] : Weight::Zero();
    342   }
    343 }
    344 
    345 
    346 }  // namespace fst
    347 
    348 #endif  // FST_LIB_SHORTEST_DISTANCE_H__
    349