1 // shortest-distance.h 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Author: allauzen (at) cs.nyu.edu (Cyril Allauzen) 16 // 17 // \file 18 // Functions and classes to find shortest distance in an FST. 19 20 #ifndef FST_LIB_SHORTEST_DISTANCE_H__ 21 #define FST_LIB_SHORTEST_DISTANCE_H__ 22 23 #include <deque> 24 25 #include "fst/lib/arcfilter.h" 26 #include "fst/lib/cache.h" 27 #include "fst/lib/queue.h" 28 #include "fst/lib/reverse.h" 29 #include "fst/lib/test-properties.h" 30 31 namespace fst { 32 33 template <class Arc, class Queue, class ArcFilter> 34 struct ShortestDistanceOptions { 35 typedef typename Arc::StateId StateId; 36 37 Queue *state_queue; // Queue discipline used; owned by caller 38 ArcFilter arc_filter; // Arc filter (e.g., limit to only epsilon graph) 39 StateId source; // If kNoStateId, use the Fst's initial state 40 float delta; // Determines the degree of convergence required 41 42 ShortestDistanceOptions(Queue *q, ArcFilter filt, StateId src = kNoStateId, 43 float d = kDelta) 44 : state_queue(q), arc_filter(filt), source(src), delta(d) {} 45 }; 46 47 48 // Computation state of the shortest-distance algorithm. Reusable 49 // information is maintained across calls to member function 50 // ShortestDistance(source) when 'retain' is true for improved 51 // efficiency when calling multiple times from different source states 52 // (e.g., in epsilon removal). 
Vector 'distance' should not be 53 // modified by the user between these calls. 54 template<class Arc, class Queue, class ArcFilter> 55 class ShortestDistanceState { 56 public: 57 typedef typename Arc::StateId StateId; 58 typedef typename Arc::Weight Weight; 59 60 ShortestDistanceState( 61 const Fst<Arc> &fst, 62 vector<Weight> *distance, 63 const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts, 64 bool retain) 65 : fst_(fst.Copy()), distance_(distance), state_queue_(opts.state_queue), 66 arc_filter_(opts.arc_filter), 67 delta_(opts.delta), retain_(retain) { 68 distance_->clear(); 69 } 70 71 ~ShortestDistanceState() { 72 delete fst_; 73 } 74 75 void ShortestDistance(StateId source); 76 77 private: 78 const Fst<Arc> *fst_; 79 vector<Weight> *distance_; 80 Queue *state_queue_; 81 ArcFilter arc_filter_; 82 float delta_; 83 bool retain_; // Retain and reuse information across calls 84 85 vector<Weight> rdistance_; // Relaxation distance. 86 vector<bool> enqueued_; // Is state enqueued? 87 vector<StateId> sources_; // Source state for ith state in 'distance_', 88 // 'rdistance_', and 'enqueued_' if retained. 89 }; 90 91 // Compute the shortest distance. If 'source' is kNoStateId, use 92 // the initial state of the Fst. 
93 template <class Arc, class Queue, class ArcFilter> 94 void ShortestDistanceState<Arc, Queue, ArcFilter>::ShortestDistance( 95 StateId source) { 96 if (fst_->Start() == kNoStateId) 97 return; 98 99 if (!(Weight::Properties() & kRightSemiring)) 100 LOG(FATAL) << "ShortestDistance: Weight needs to be right distributive: " 101 << Weight::Type(); 102 103 state_queue_->Clear(); 104 105 if (!retain_) { 106 distance_->clear(); 107 rdistance_.clear(); 108 enqueued_.clear(); 109 } 110 111 if (source == kNoStateId) 112 source = fst_->Start(); 113 114 while ((StateId)distance_->size() <= source) { 115 distance_->push_back(Weight::Zero()); 116 rdistance_.push_back(Weight::Zero()); 117 enqueued_.push_back(false); 118 } 119 if (retain_) { 120 while ((StateId)sources_.size() <= source) 121 sources_.push_back(kNoStateId); 122 sources_[source] = source; 123 } 124 (*distance_)[source] = Weight::One(); 125 rdistance_[source] = Weight::One(); 126 enqueued_[source] = true; 127 128 state_queue_->Enqueue(source); 129 130 while (!state_queue_->Empty()) { 131 StateId s = state_queue_->Head(); 132 state_queue_->Dequeue(); 133 while ((StateId)distance_->size() <= s) { 134 distance_->push_back(Weight::Zero()); 135 rdistance_.push_back(Weight::Zero()); 136 enqueued_.push_back(false); 137 } 138 enqueued_[s] = false; 139 Weight r = rdistance_[s]; 140 rdistance_[s] = Weight::Zero(); 141 for (ArcIterator< Fst<Arc> > aiter(*fst_, s); 142 !aiter.Done(); 143 aiter.Next()) { 144 const Arc &arc = aiter.Value(); 145 if (!arc_filter_(arc) || arc.weight == Weight::Zero()) 146 continue; 147 while ((StateId)distance_->size() <= arc.nextstate) { 148 distance_->push_back(Weight::Zero()); 149 rdistance_.push_back(Weight::Zero()); 150 enqueued_.push_back(false); 151 } 152 if (retain_) { 153 while ((StateId)sources_.size() <= arc.nextstate) 154 sources_.push_back(kNoStateId); 155 if (sources_[arc.nextstate] != source) { 156 (*distance_)[arc.nextstate] = Weight::Zero(); 157 rdistance_[arc.nextstate] = 
Weight::Zero(); 158 enqueued_[arc.nextstate] = false; 159 sources_[arc.nextstate] = source; 160 } 161 } 162 Weight &nd = (*distance_)[arc.nextstate]; 163 Weight &nr = rdistance_[arc.nextstate]; 164 Weight w = Times(r, arc.weight); 165 if (!ApproxEqual(nd, Plus(nd, w), delta_)) { 166 nd = Plus(nd, w); 167 nr = Plus(nr, w); 168 if (!enqueued_[arc.nextstate]) { 169 state_queue_->Enqueue(arc.nextstate); 170 enqueued_[arc.nextstate] = true; 171 } else { 172 state_queue_->Update(arc.nextstate); 173 } 174 } 175 } 176 } 177 } 178 179 180 // Shortest-distance algorithm: this version allows fine control 181 // via the options argument. See below for a simpler interface. 182 // 183 // This computes the shortest distance from the 'opts.source' state to 184 // each visited state S and stores the value in the 'distance' vector. 185 // An unvisited state S has distance Zero(), which will be stored in 186 // the 'distance' vector if S is less than the maximum visited state. 187 // The state queue discipline, arc filter, and convergence delta are 188 // taken in the options argument. 189 190 // The weights must must be right distributive and k-closed (i.e., 1 + 191 // x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k). 192 // 193 // The algorithm is from Mohri, "Semiring Framweork and Algorithms for 194 // Shortest-Distance Problems", Journal of Automata, Languages and 195 // Combinatorics 7(3):321-350, 2002. The complexity of algorithm 196 // depends on the properties of the semiring and the queue discipline 197 // used. Refer to the paper for more details. 198 template<class Arc, class Queue, class ArcFilter> 199 void ShortestDistance( 200 const Fst<Arc> &fst, 201 vector<typename Arc::Weight> *distance, 202 const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts) { 203 204 ShortestDistanceState<Arc, Queue, ArcFilter> 205 sd_state(fst, distance, opts, false); 206 sd_state.ShortestDistance(opts.source); 207 } 208 209 // Shortest-distance algorithm: simplified interface. 
// See above for a
// version that allows finer control.
//
// If 'reverse' is false, this computes the shortest distance from the
// initial state to each state S and stores the value in the
// 'distance' vector. If 'reverse' is true, this computes the shortest
// distance from each state to the final states. An unvisited state S
// has distance Zero(), which will be stored in the 'distance' vector
// if S is less than the maximum visited state. The state queue
// discipline is automatically selected.
//
// The weights must be right (left) distributive if reverse is
// false (true) and k-closed (i.e., 1 + x + x^2 + ... + x^(k +1) = 1 +
// x + x^2 + ... + x^k).
//
// The algorithm is from Mohri, "Semiring Frameworks and Algorithms for
// Shortest-Distance Problems", Journal of Automata, Languages and
// Combinatorics 7(3):321-350, 2002. The complexity of algorithm
// depends on the properties of the semiring and the queue discipline
// used. Refer to the paper for more details.
229 template<class Arc> 230 void ShortestDistance(const Fst<Arc> &fst, 231 vector<typename Arc::Weight> *distance, 232 bool reverse = false) { 233 typedef typename Arc::StateId StateId; 234 typedef typename Arc::Weight Weight; 235 236 if (!reverse) { 237 AnyArcFilter<Arc> arc_filter; 238 AutoQueue<StateId> state_queue(fst, distance, arc_filter); 239 ShortestDistanceOptions< Arc, AutoQueue<StateId>, AnyArcFilter<Arc> > 240 opts(&state_queue, arc_filter); 241 ShortestDistance(fst, distance, opts); 242 } else { 243 typedef ReverseArc<Arc> ReverseArc; 244 typedef typename ReverseArc::Weight ReverseWeight; 245 AnyArcFilter<ReverseArc> rarc_filter; 246 VectorFst<ReverseArc> rfst; 247 Reverse(fst, &rfst); 248 vector<ReverseWeight> rdistance; 249 AutoQueue<StateId> state_queue(rfst, &rdistance, rarc_filter); 250 ShortestDistanceOptions< ReverseArc, AutoQueue<StateId>, 251 AnyArcFilter<ReverseArc> > 252 ropts(&state_queue, rarc_filter); 253 ShortestDistance(rfst, &rdistance, ropts); 254 distance->clear(); 255 while (distance->size() < rdistance.size() - 1) 256 distance->push_back(rdistance[distance->size() + 1].Reverse()); 257 } 258 } 259 260 } // namespace fst 261 262 #endif // FST_LIB_SHORTEST_DISTANCE_H__ 263