1 // shortest-distance.h 2 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Copyright 2005-2010 Google, Inc. 16 // Author: allauzen (at) google.com (Cyril Allauzen) 17 // 18 // \file 19 // Functions and classes to find shortest distance in an FST. 20 21 #ifndef FST_LIB_SHORTEST_DISTANCE_H__ 22 #define FST_LIB_SHORTEST_DISTANCE_H__ 23 24 #include <deque> 25 using std::deque; 26 #include <vector> 27 using std::vector; 28 29 #include <fst/arcfilter.h> 30 #include <fst/cache.h> 31 #include <fst/queue.h> 32 #include <fst/reverse.h> 33 #include <fst/test-properties.h> 34 35 36 namespace fst { 37 38 template <class Arc, class Queue, class ArcFilter> 39 struct ShortestDistanceOptions { 40 typedef typename Arc::StateId StateId; 41 42 Queue *state_queue; // Queue discipline used; owned by caller 43 ArcFilter arc_filter; // Arc filter (e.g., limit to only epsilon graph) 44 StateId source; // If kNoStateId, use the Fst's initial state 45 float delta; // Determines the degree of convergence required 46 bool first_path; // For a semiring with the path property (o.w. 47 // undefined), compute the shortest-distances along 48 // along the first path to a final state found 49 // by the algorithm. That path is the shortest-path 50 // only if the FST has a unique final state (or all 51 // the final states have the same final weight), the 52 // queue discipline is shortest-first and all the 53 // weights in the FST are between One() and Zero() 54 // according to NaturalLess. 55 56 ShortestDistanceOptions(Queue *q, ArcFilter filt, StateId src = kNoStateId, 57 float d = kDelta) 58 : state_queue(q), arc_filter(filt), source(src), delta(d), 59 first_path(false) {} 60 }; 61 62 63 // Computation state of the shortest-distance algorithm. Reusable 64 // information is maintained across calls to member function 65 // ShortestDistance(source) when 'retain' is true for improved 66 // efficiency when calling multiple times from different source states 67 // (e.g., in epsilon removal). Contrary to usual conventions, 'fst' 68 // may not be freed before this class. Vector 'distance' should not be 69 // modified by the user between these calls. 70 // The Error() method returns true if an error was encountered. 71 template<class Arc, class Queue, class ArcFilter> 72 class ShortestDistanceState { 73 public: 74 typedef typename Arc::StateId StateId; 75 typedef typename Arc::Weight Weight; 76 77 ShortestDistanceState( 78 const Fst<Arc> &fst, 79 vector<Weight> *distance, 80 const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts, 81 bool retain) 82 : fst_(fst), distance_(distance), state_queue_(opts.state_queue), 83 arc_filter_(opts.arc_filter), delta_(opts.delta), 84 first_path_(opts.first_path), retain_(retain), source_id_(0), 85 error_(false) { 86 distance_->clear(); 87 } 88 89 ~ShortestDistanceState() {} 90 91 void ShortestDistance(StateId source); 92 93 bool Error() const { return error_; } 94 95 private: 96 const Fst<Arc> &fst_; 97 vector<Weight> *distance_; 98 Queue *state_queue_; 99 ArcFilter arc_filter_; 100 float delta_; 101 bool first_path_; 102 bool retain_; // Retain and reuse information across calls 103 104 vector<Weight> rdistance_; // Relaxation distance. 105 vector<bool> enqueued_; // Is state enqueued? 106 vector<StateId> sources_; // Source ID for ith state in 'distance_', 107 // 'rdistance_', and 'enqueued_' if retained. 108 StateId source_id_; // Unique ID characterizing each call to SD 109 110 bool error_; 111 }; 112 113 // Compute the shortest distance. If 'source' is kNoStateId, use 114 // the initial state of the Fst. 115 template <class Arc, class Queue, class ArcFilter> 116 void ShortestDistanceState<Arc, Queue, ArcFilter>::ShortestDistance( 117 StateId source) { 118 if (fst_.Start() == kNoStateId) { 119 if (fst_.Properties(kError, false)) error_ = true; 120 return; 121 } 122 123 if (!(Weight::Properties() & kRightSemiring)) { 124 FSTERROR() << "ShortestDistance: Weight needs to be right distributive: " 125 << Weight::Type(); 126 error_ = true; 127 return; 128 } 129 130 if (first_path_ && !(Weight::Properties() & kPath)) { 131 FSTERROR() << "ShortestDistance: first_path option disallowed when " 132 << "Weight does not have the path property: " 133 << Weight::Type(); 134 error_ = true; 135 return; 136 } 137 138 state_queue_->Clear(); 139 140 if (!retain_) { 141 distance_->clear(); 142 rdistance_.clear(); 143 enqueued_.clear(); 144 } 145 146 if (source == kNoStateId) 147 source = fst_.Start(); 148 149 while (distance_->size() <= source) { 150 distance_->push_back(Weight::Zero()); 151 rdistance_.push_back(Weight::Zero()); 152 enqueued_.push_back(false); 153 } 154 if (retain_) { 155 while (sources_.size() <= source) 156 sources_.push_back(kNoStateId); 157 sources_[source] = source_id_; 158 } 159 (*distance_)[source] = Weight::One(); 160 rdistance_[source] = Weight::One(); 161 enqueued_[source] = true; 162 163 state_queue_->Enqueue(source); 164 165 while (!state_queue_->Empty()) { 166 StateId s = state_queue_->Head(); 167 state_queue_->Dequeue(); 168 while (distance_->size() <= s) { 169 distance_->push_back(Weight::Zero()); 170 rdistance_.push_back(Weight::Zero()); 171 enqueued_.push_back(false); 172 } 173 if (first_path_ && (fst_.Final(s) != Weight::Zero())) 174 break; 175 enqueued_[s] = false; 176 Weight r = rdistance_[s]; 177 rdistance_[s] = Weight::Zero(); 178 for (ArcIterator< Fst<Arc> > aiter(fst_, s); 179 !aiter.Done(); 180 aiter.Next()) { 181 const Arc &arc = aiter.Value(); 182 if (!arc_filter_(arc)) 183 continue; 184 while (distance_->size() <= arc.nextstate) { 185 distance_->push_back(Weight::Zero()); 186 rdistance_.push_back(Weight::Zero()); 187 enqueued_.push_back(false); 188 } 189 if (retain_) { 190 while (sources_.size() <= arc.nextstate) 191 sources_.push_back(kNoStateId); 192 if (sources_[arc.nextstate] != source_id_) { 193 (*distance_)[arc.nextstate] = Weight::Zero(); 194 rdistance_[arc.nextstate] = Weight::Zero(); 195 enqueued_[arc.nextstate] = false; 196 sources_[arc.nextstate] = source_id_; 197 } 198 } 199 Weight &nd = (*distance_)[arc.nextstate]; 200 Weight &nr = rdistance_[arc.nextstate]; 201 Weight w = Times(r, arc.weight); 202 if (!ApproxEqual(nd, Plus(nd, w), delta_)) { 203 nd = Plus(nd, w); 204 nr = Plus(nr, w); 205 if (!nd.Member() || !nr.Member()) { 206 error_ = true; 207 return; 208 } 209 if (!enqueued_[arc.nextstate]) { 210 state_queue_->Enqueue(arc.nextstate); 211 enqueued_[arc.nextstate] = true; 212 } else { 213 state_queue_->Update(arc.nextstate); 214 } 215 } 216 } 217 } 218 ++source_id_; 219 if (fst_.Properties(kError, false)) error_ = true; 220 } 221 222 223 // Shortest-distance algorithm: this version allows fine control 224 // via the options argument. See below for a simpler interface. 225 // 226 // This computes the shortest distance from the 'opts.source' state to 227 // each visited state S and stores the value in the 'distance' vector. 228 // An unvisited state S has distance Zero(), which will be stored in 229 // the 'distance' vector if S is less than the maximum visited state. 230 // The state queue discipline, arc filter, and convergence delta are 231 // taken in the options argument. 232 // The 'distance' vector will contain a unique element for which 233 // Member() is false if an error was encountered. 234 // 235 // The weights must must be right distributive and k-closed (i.e., 1 + 236 // x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k). 237 // 238 // The algorithm is from Mohri, "Semiring Framweork and Algorithms for 239 // Shortest-Distance Problems", Journal of Automata, Languages and 240 // Combinatorics 7(3):321-350, 2002. The complexity of algorithm 241 // depends on the properties of the semiring and the queue discipline 242 // used. Refer to the paper for more details. 243 template<class Arc, class Queue, class ArcFilter> 244 void ShortestDistance( 245 const Fst<Arc> &fst, 246 vector<typename Arc::Weight> *distance, 247 const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts) { 248 249 ShortestDistanceState<Arc, Queue, ArcFilter> 250 sd_state(fst, distance, opts, false); 251 sd_state.ShortestDistance(opts.source); 252 if (sd_state.Error()) { 253 distance->clear(); 254 distance->resize(1, Arc::Weight::NoWeight()); 255 } 256 } 257 258 // Shortest-distance algorithm: simplified interface. See above for a 259 // version that allows finer control. 260 // 261 // If 'reverse' is false, this computes the shortest distance from the 262 // initial state to each state S and stores the value in the 263 // 'distance' vector. If 'reverse' is true, this computes the shortest 264 // distance from each state to the final states. An unvisited state S 265 // has distance Zero(), which will be stored in the 'distance' vector 266 // if S is less than the maximum visited state. The state queue 267 // discipline is automatically-selected. 268 // The 'distance' vector will contain a unique element for which 269 // Member() is false if an error was encountered. 270 // 271 // The weights must must be right (left) distributive if reverse is 272 // false (true) and k-closed (i.e., 1 + x + x^2 + ... + x^(k +1) = 1 + 273 // x + x^2 + ... + x^k). 274 // 275 // The algorithm is from Mohri, "Semiring Framweork and Algorithms for 276 // Shortest-Distance Problems", Journal of Automata, Languages and 277 // Combinatorics 7(3):321-350, 2002. The complexity of algorithm 278 // depends on the properties of the semiring and the queue discipline 279 // used. Refer to the paper for more details. 280 template<class Arc> 281 void ShortestDistance(const Fst<Arc> &fst, 282 vector<typename Arc::Weight> *distance, 283 bool reverse = false, 284 float delta = kDelta) { 285 typedef typename Arc::StateId StateId; 286 typedef typename Arc::Weight Weight; 287 288 if (!reverse) { 289 AnyArcFilter<Arc> arc_filter; 290 AutoQueue<StateId> state_queue(fst, distance, arc_filter); 291 ShortestDistanceOptions< Arc, AutoQueue<StateId>, AnyArcFilter<Arc> > 292 opts(&state_queue, arc_filter); 293 opts.delta = delta; 294 ShortestDistance(fst, distance, opts); 295 } else { 296 typedef ReverseArc<Arc> ReverseArc; 297 typedef typename ReverseArc::Weight ReverseWeight; 298 AnyArcFilter<ReverseArc> rarc_filter; 299 VectorFst<ReverseArc> rfst; 300 Reverse(fst, &rfst); 301 vector<ReverseWeight> rdistance; 302 AutoQueue<StateId> state_queue(rfst, &rdistance, rarc_filter); 303 ShortestDistanceOptions< ReverseArc, AutoQueue<StateId>, 304 AnyArcFilter<ReverseArc> > 305 ropts(&state_queue, rarc_filter); 306 ropts.delta = delta; 307 ShortestDistance(rfst, &rdistance, ropts); 308 distance->clear(); 309 if (rdistance.size() == 1 && !rdistance[0].Member()) { 310 distance->resize(1, Arc::Weight::NoWeight()); 311 return; 312 } 313 while (distance->size() < rdistance.size() - 1) 314 distance->push_back(rdistance[distance->size() + 1].Reverse()); 315 } 316 } 317 318 319 // Return the sum of the weight of all successful paths in an FST, i.e., 320 // the shortest-distance from the initial state to the final states. 321 // Returns a weight such that Member() is false if an error was encountered. 322 template <class Arc> 323 typename Arc::Weight ShortestDistance(const Fst<Arc> &fst, float delta = kDelta) { 324 typedef typename Arc::Weight Weight; 325 typedef typename Arc::StateId StateId; 326 vector<Weight> distance; 327 if (Weight::Properties() & kRightSemiring) { 328 ShortestDistance(fst, &distance, false, delta); 329 if (distance.size() == 1 && !distance[0].Member()) 330 return Arc::Weight::NoWeight(); 331 Weight sum = Weight::Zero(); 332 for (StateId s = 0; s < distance.size(); ++s) 333 sum = Plus(sum, Times(distance[s], fst.Final(s))); 334 return sum; 335 } else { 336 ShortestDistance(fst, &distance, true, delta); 337 StateId s = fst.Start(); 338 if (distance.size() == 1 && !distance[0].Member()) 339 return Arc::Weight::NoWeight(); 340 return s != kNoStateId && s < distance.size() ? 341 distance[s] : Weight::Zero(); 342 } 343 } 344 345 346 } // namespace fst 347 348 #endif // FST_LIB_SHORTEST_DISTANCE_H__ 349