Home | History | Annotate | Download | only in lib
      1 // shortest-distance.h
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Author: allauzen (at) cs.nyu.edu (Cyril Allauzen)
     16 //
     17 // \file
     18 // Functions and classes to find shortest distance in an FST.
     19 
     20 #ifndef FST_LIB_SHORTEST_DISTANCE_H__
     21 #define FST_LIB_SHORTEST_DISTANCE_H__
     22 
     23 #include <deque>
     24 
     25 #include "fst/lib/arcfilter.h"
     26 #include "fst/lib/cache.h"
     27 #include "fst/lib/queue.h"
     28 #include "fst/lib/reverse.h"
     29 #include "fst/lib/test-properties.h"
     30 
     31 namespace fst {
     32 
     33 template <class Arc, class Queue, class ArcFilter>
     34 struct ShortestDistanceOptions {
     35   typedef typename Arc::StateId StateId;
     36 
     37   Queue *state_queue;    // Queue discipline used; owned by caller
     38   ArcFilter arc_filter;  // Arc filter (e.g., limit to only epsilon graph)
     39   StateId source;        // If kNoStateId, use the Fst's initial state
     40   float delta;           // Determines the degree of convergence required
     41 
     42   ShortestDistanceOptions(Queue *q, ArcFilter filt, StateId src = kNoStateId,
     43                           float d = kDelta)
     44       : state_queue(q), arc_filter(filt), source(src), delta(d) {}
     45 };
     46 
     47 
     48 // Computation state of the shortest-distance algorithm. Reusable
     49 // information is maintained across calls to member function
     50 // ShortestDistance(source) when 'retain' is true for improved
     51 // efficiency when calling multiple times from different source states
     52 // (e.g., in epsilon removal). Vector 'distance' should not be
     53 // modified by the user between these calls.
     54 template<class Arc, class Queue, class ArcFilter>
     55 class ShortestDistanceState {
     56  public:
     57   typedef typename Arc::StateId StateId;
     58   typedef typename Arc::Weight Weight;
     59 
     60   ShortestDistanceState(
     61       const Fst<Arc> &fst,
     62       vector<Weight> *distance,
     63       const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts,
     64       bool retain)
     65       : fst_(fst.Copy()), distance_(distance), state_queue_(opts.state_queue),
     66         arc_filter_(opts.arc_filter),
     67         delta_(opts.delta), retain_(retain) {
     68     distance_->clear();
     69   }
     70 
     71   ~ShortestDistanceState() {
     72     delete fst_;
     73   }
     74 
     75   void ShortestDistance(StateId source);
     76 
     77  private:
     78   const Fst<Arc> *fst_;
     79   vector<Weight> *distance_;
     80   Queue *state_queue_;
     81   ArcFilter arc_filter_;
     82   float delta_;
     83   bool retain_;                  // Retain and reuse information across calls
     84 
     85   vector<Weight> rdistance_;    // Relaxation distance.
     86   vector<bool> enqueued_;       // Is state enqueued?
     87   vector<StateId> sources_;     // Source state for ith state in 'distance_',
     88                                 //  'rdistance_', and 'enqueued_' if retained.
     89 };
     90 
     91 // Compute the shortest distance. If 'source' is kNoStateId, use
     92 // the initial state of the Fst.
     93 template <class Arc, class Queue, class ArcFilter>
     94 void ShortestDistanceState<Arc, Queue, ArcFilter>::ShortestDistance(
     95     StateId source) {
     96   if (fst_->Start() == kNoStateId)
     97     return;
     98 
     99   if (!(Weight::Properties() & kRightSemiring))
    100     LOG(FATAL) << "ShortestDistance: Weight needs to be right distributive: "
    101                << Weight::Type();
    102 
    103   state_queue_->Clear();
    104 
    105   if (!retain_) {
    106     distance_->clear();
    107     rdistance_.clear();
    108     enqueued_.clear();
    109   }
    110 
    111   if (source == kNoStateId)
    112     source = fst_->Start();
    113 
    114   while ((StateId)distance_->size() <= source) {
    115     distance_->push_back(Weight::Zero());
    116     rdistance_.push_back(Weight::Zero());
    117     enqueued_.push_back(false);
    118   }
    119   if (retain_) {
    120     while ((StateId)sources_.size() <= source)
    121       sources_.push_back(kNoStateId);
    122     sources_[source] = source;
    123   }
    124   (*distance_)[source] = Weight::One();
    125   rdistance_[source] = Weight::One();
    126   enqueued_[source] = true;
    127 
    128   state_queue_->Enqueue(source);
    129 
    130   while (!state_queue_->Empty()) {
    131     StateId s = state_queue_->Head();
    132     state_queue_->Dequeue();
    133     while ((StateId)distance_->size() <= s) {
    134       distance_->push_back(Weight::Zero());
    135       rdistance_.push_back(Weight::Zero());
    136       enqueued_.push_back(false);
    137     }
    138     enqueued_[s] = false;
    139     Weight r = rdistance_[s];
    140     rdistance_[s] = Weight::Zero();
    141     for (ArcIterator< Fst<Arc> > aiter(*fst_, s);
    142          !aiter.Done();
    143          aiter.Next()) {
    144       const Arc &arc = aiter.Value();
    145       if (!arc_filter_(arc) || arc.weight == Weight::Zero())
    146         continue;
    147       while ((StateId)distance_->size() <= arc.nextstate) {
    148         distance_->push_back(Weight::Zero());
    149         rdistance_.push_back(Weight::Zero());
    150         enqueued_.push_back(false);
    151       }
    152       if (retain_) {
    153         while ((StateId)sources_.size() <= arc.nextstate)
    154           sources_.push_back(kNoStateId);
    155         if (sources_[arc.nextstate] != source) {
    156           (*distance_)[arc.nextstate] = Weight::Zero();
    157           rdistance_[arc.nextstate] = Weight::Zero();
    158           enqueued_[arc.nextstate] = false;
    159           sources_[arc.nextstate] = source;
    160         }
    161       }
    162       Weight &nd = (*distance_)[arc.nextstate];
    163       Weight &nr = rdistance_[arc.nextstate];
    164       Weight w = Times(r, arc.weight);
    165       if (!ApproxEqual(nd, Plus(nd, w), delta_)) {
    166         nd = Plus(nd, w);
    167         nr = Plus(nr, w);
    168         if (!enqueued_[arc.nextstate]) {
    169           state_queue_->Enqueue(arc.nextstate);
    170           enqueued_[arc.nextstate] = true;
    171         } else {
    172           state_queue_->Update(arc.nextstate);
    173         }
    174       }
    175     }
    176   }
    177 }
    178 
    179 
    180 // Shortest-distance algorithm: this version allows fine control
    181 // via the options argument. See below for a simpler interface.
    182 //
    183 // This computes the shortest distance from the 'opts.source' state to
    184 // each visited state S and stores the value in the 'distance' vector.
    185 // An unvisited state S has distance Zero(), which will be stored in
    186 // the 'distance' vector if S is less than the maximum visited state.
    187 // The state queue discipline, arc filter, and convergence delta are
    188 // taken in the options argument.
    189 
    190 // The weights must must be right distributive and k-closed (i.e., 1 +
    191 // x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k).
    192 //
    193 // The algorithm is from Mohri, "Semiring Framweork and Algorithms for
    194 // Shortest-Distance Problems", Journal of Automata, Languages and
    195 // Combinatorics 7(3):321-350, 2002. The complexity of algorithm
    196 // depends on the properties of the semiring and the queue discipline
    197 // used. Refer to the paper for more details.
    198 template<class Arc, class Queue, class ArcFilter>
    199 void ShortestDistance(
    200     const Fst<Arc> &fst,
    201     vector<typename Arc::Weight> *distance,
    202     const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts) {
    203 
    204   ShortestDistanceState<Arc, Queue, ArcFilter>
    205     sd_state(fst, distance, opts, false);
    206   sd_state.ShortestDistance(opts.source);
    207 }
    208 
    209 // Shortest-distance algorithm: simplified interface. See above for a
    210 // version that allows finer control.
    211 //
    212 // If 'reverse' is false, this computes the shortest distance from the
    213 // initial state to each state S and stores the value in the
    214 // 'distance' vector. If 'reverse' is true, this computes the shortest
    215 // distance from each state to the final states.  An unvisited state S
    216 // has distance Zero(), which will be stored in the 'distance' vector
    217 // if S is less than the maximum visited state.  The state queue
    218 // discipline is automatically-selected.
    219 //
    220 // The weights must must be right (left) distributive if reverse is
    221 // false (true) and k-closed (i.e., 1 + x + x^2 + ... + x^(k +1) = 1 +
    222 // x + x^2 + ... + x^k).
    223 //
    224 // The algorithm is from Mohri, "Semiring Framweork and Algorithms for
    225 // Shortest-Distance Problems", Journal of Automata, Languages and
    226 // Combinatorics 7(3):321-350, 2002. The complexity of algorithm
    227 // depends on the properties of the semiring and the queue discipline
    228 // used. Refer to the paper for more details.
    229 template<class Arc>
    230 void ShortestDistance(const Fst<Arc> &fst,
    231                       vector<typename Arc::Weight> *distance,
    232                       bool reverse = false) {
    233   typedef typename Arc::StateId StateId;
    234   typedef typename Arc::Weight Weight;
    235 
    236   if (!reverse) {
    237     AnyArcFilter<Arc> arc_filter;
    238     AutoQueue<StateId> state_queue(fst, distance, arc_filter);
    239     ShortestDistanceOptions< Arc, AutoQueue<StateId>, AnyArcFilter<Arc> >
    240       opts(&state_queue, arc_filter);
    241     ShortestDistance(fst, distance, opts);
    242   } else {
    243     typedef ReverseArc<Arc> ReverseArc;
    244     typedef typename ReverseArc::Weight ReverseWeight;
    245     AnyArcFilter<ReverseArc> rarc_filter;
    246     VectorFst<ReverseArc> rfst;
    247     Reverse(fst, &rfst);
    248     vector<ReverseWeight> rdistance;
    249     AutoQueue<StateId> state_queue(rfst, &rdistance, rarc_filter);
    250     ShortestDistanceOptions< ReverseArc, AutoQueue<StateId>,
    251       AnyArcFilter<ReverseArc> >
    252       ropts(&state_queue, rarc_filter);
    253     ShortestDistance(rfst, &rdistance, ropts);
    254     distance->clear();
    255     while (distance->size() < rdistance.size() - 1)
    256       distance->push_back(rdistance[distance->size() + 1].Reverse());
    257   }
    258 }
    259 
    260 }  // namespace fst
    261 
    262 #endif  // FST_LIB_SHORTEST_DISTANCE_H__
    263