1 // minimize.h 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // 16 // \file Functions and classes to minimize a finite state acceptor 17 18 #ifndef FST_LIB_MINIMIZE_H__ 19 #define FST_LIB_MINIMIZE_H__ 20 21 #include <algorithm> 22 #include <map> 23 #include <queue> 24 25 #include "fst/lib/arcsort.h" 26 #include "fst/lib/arcsum.h" 27 #include "fst/lib/connect.h" 28 #include "fst/lib/dfs-visit.h" 29 #include "fst/lib/encode.h" 30 #include "fst/lib/factor-weight.h" 31 #include "fst/lib/fst.h" 32 #include "fst/lib/mutable-fst.h" 33 #include "fst/lib/partition.h" 34 #include "fst/lib/push.h" 35 #include "fst/lib/queue.h" 36 #include "fst/lib/reverse.h" 37 38 namespace fst { 39 40 // comparator for creating partition based on sorting on 41 // - states 42 // - final weight 43 // - out degree, 44 // - (input label, output label, weight, destination_block) 45 template <class A> 46 class StateComparator { 47 public: 48 typedef typename A::StateId StateId; 49 typedef typename A::Weight Weight; 50 51 static const int32 kCompareFinal = 0x0000001; 52 static const int32 kCompareOutDegree = 0x0000002; 53 static const int32 kCompareArcs = 0x0000004; 54 static const int32 kCompareAll = (kCompareFinal | 55 kCompareOutDegree | 56 kCompareArcs); 57 58 StateComparator(const Fst<A>& fst, 59 const Partition<typename A::StateId>& partition, 60 int32 flags = kCompareAll) 61 : fst_(fst), partition_(partition), flags_(flags) {} 62 63 // compare state x with state y based on sort criteria 64 bool operator()(const StateId x, const StateId y) const { 65 // check for final state equivalence 66 if (flags_ & kCompareFinal) { 67 const ssize_t xfinal = fst_.Final(x).Hash(); 68 const ssize_t yfinal = fst_.Final(y).Hash(); 69 if (xfinal < yfinal) return true; 70 else if (xfinal > yfinal) return false; 71 } 72 73 if (flags_ & kCompareOutDegree) { 74 // check for # arcs 75 if (fst_.NumArcs(x) < fst_.NumArcs(y)) return true; 76 if (fst_.NumArcs(x) > fst_.NumArcs(y)) return false; 77 78 if (flags_ & kCompareArcs) { 79 // # arcs are equal, check for arc match 80 for (ArcIterator<Fst<A> > aiter1(fst_, x), aiter2(fst_, y); 81 !aiter1.Done() && !aiter2.Done(); aiter1.Next(), aiter2.Next()) { 82 const A& arc1 = aiter1.Value(); 83 const A& arc2 = aiter2.Value(); 84 if (arc1.ilabel < arc2.ilabel) return true; 85 if (arc1.ilabel > arc2.ilabel) return false; 86 87 if (partition_.class_id(arc1.nextstate) < 88 partition_.class_id(arc2.nextstate)) return true; 89 if (partition_.class_id(arc1.nextstate) > 90 partition_.class_id(arc2.nextstate)) return false; 91 } 92 } 93 } 94 95 return false; 96 } 97 98 private: 99 const Fst<A>& fst_; 100 const Partition<typename A::StateId>& partition_; 101 const int32 flags_; 102 }; 103 104 // Computes equivalence classes for cyclic Fsts. For cyclic minimization 105 // we use the classic HopCroft minimization algorithm, which is of 106 // 107 // O(E)log(N), 108 // 109 // where E is the number of edges in the machine and N is number of states. 110 // 111 // The following paper describes the original algorithm 112 // An N Log N algorithm for minimizing states in a finite automaton 113 // by John HopCroft, January 1971 114 // 115 template <class A, class Queue> 116 class CyclicMinimizer { 117 public: 118 typedef typename A::Label Label; 119 typedef typename A::StateId StateId; 120 typedef typename A::StateId ClassId; 121 typedef typename A::Weight Weight; 122 typedef ReverseArc<A> RevA; 123 124 CyclicMinimizer(const ExpandedFst<A>& fst) { 125 Initialize(fst); 126 Compute(fst); 127 } 128 129 ~CyclicMinimizer() { 130 delete aiter_queue_; 131 } 132 133 const Partition<StateId>& partition() const { 134 return P_; 135 } 136 137 // helper classes 138 private: 139 typedef ArcIterator<Fst<RevA> > ArcIter; 140 class ArcIterCompare { 141 public: 142 ArcIterCompare(const Partition<StateId>& partition) 143 : partition_(partition) {} 144 145 ArcIterCompare(const ArcIterCompare& comp) 146 : partition_(comp.partition_) {} 147 148 // compare two iterators based on there input labels, and proto state 149 // (partition class Ids) 150 bool operator()(const ArcIter* x, const ArcIter* y) const { 151 const RevA& xarc = x->Value(); 152 const RevA& yarc = y->Value(); 153 return (xarc.ilabel > yarc.ilabel); 154 } 155 156 private: 157 const Partition<StateId>& partition_; 158 }; 159 160 typedef priority_queue<ArcIter*, vector<ArcIter*>, ArcIterCompare> 161 ArcIterQueue; 162 163 // helper methods 164 private: 165 // prepartitions the space into equivalence classes with 166 // same final weight 167 // same # arcs per state 168 // same outgoing arcs 169 void PrePartition(const Fst<A>& fst) { 170 VLOG(5) << "PrePartition"; 171 172 typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap; 173 StateComparator<A> comp(fst, P_, StateComparator<A>::kCompareFinal); 174 EquivalenceMap equiv_map(comp); 175 176 StateIterator<Fst<A> > siter(fst); 177 StateId class_id = P_.AddClass(); 178 P_.Add(siter.Value(), class_id); 179 equiv_map[siter.Value()] = class_id; 180 L_.Enqueue(class_id); 181 for (siter.Next(); !siter.Done(); siter.Next()) { 182 StateId s = siter.Value(); 183 typename EquivalenceMap::const_iterator it = equiv_map.find(s); 184 if (it == equiv_map.end()) { 185 class_id = P_.AddClass(); 186 P_.Add(s, class_id); 187 equiv_map[s] = class_id; 188 L_.Enqueue(class_id); 189 } else { 190 P_.Add(s, it->second); 191 equiv_map[s] = it->second; 192 } 193 } 194 195 VLOG(5) << "Initial Partition: " << P_.num_classes(); 196 } 197 198 // - Create inverse transition Tr_ = rev(fst) 199 // - loop over states in fst and split on final, creating two blocks 200 // in the partition corresponding to final, non-final 201 void Initialize(const Fst<A>& fst) { 202 // construct Tr 203 Reverse(fst, &Tr_); 204 ILabelCompare<RevA> ilabel_comp; 205 ArcSort(&Tr_, ilabel_comp); 206 207 // initial split (F, S - F) 208 P_.Initialize(Tr_.NumStates() - 1); 209 210 // prep partition 211 PrePartition(fst); 212 213 // allocate arc iterator queue 214 ArcIterCompare comp(P_); 215 aiter_queue_ = new ArcIterQueue(comp); 216 } 217 218 // partition all classes with destination C 219 void Split(ClassId C) { 220 // Prep priority queue. Open arc iterator for each state in C, and 221 // insert into priority queue. 222 for (PartitionIterator<StateId> siter(P_, C); 223 !siter.Done(); siter.Next()) { 224 StateId s = siter.Value(); 225 if (Tr_.NumArcs(s + 1)) 226 aiter_queue_->push(new ArcIterator<Fst<RevA> >(Tr_, s + 1)); 227 } 228 229 // Now pop arc iterator from queue, split entering equivalence class 230 // re-insert updated iterator into queue. 231 Label prev_label = -1; 232 while (!aiter_queue_->empty()) { 233 ArcIterator<Fst<RevA> >* aiter = aiter_queue_->top(); 234 aiter_queue_->pop(); 235 if (aiter->Done()) { 236 delete aiter; 237 continue; 238 } 239 240 const RevA& arc = aiter->Value(); 241 StateId from_state = aiter->Value().nextstate - 1; 242 Label from_label = arc.ilabel; 243 if (prev_label != from_label) 244 P_.FinalizeSplit(&L_); 245 246 StateId from_class = P_.class_id(from_state); 247 if (P_.class_size(from_class) > 1) 248 P_.SplitOn(from_state); 249 250 prev_label = from_label; 251 aiter->Next(); 252 if (aiter->Done()) 253 delete aiter; 254 else 255 aiter_queue_->push(aiter); 256 } 257 P_.FinalizeSplit(&L_); 258 } 259 260 // Main loop for hopcroft minimization. 261 void Compute(const Fst<A>& fst) { 262 // process active classes (FIFO, or FILO) 263 while (!L_.Empty()) { 264 ClassId C = L_.Head(); 265 L_.Dequeue(); 266 267 // split on C, all labels in C 268 Split(C); 269 } 270 } 271 272 // helper data 273 private: 274 // Partioning of states into equivalence classes 275 Partition<StateId> P_; 276 277 // L = set of active classes to be processed in partition P 278 Queue L_; 279 280 // reverse transition function 281 VectorFst<RevA> Tr_; 282 283 // Priority queue of open arc iterators for all states in the 'splitter' 284 // equivalence class 285 ArcIterQueue* aiter_queue_; 286 }; 287 288 289 // Computes equivalence classes for acyclic Fsts. The implementation details 290 // for this algorithms is documented by the following paper. 291 // 292 // Minimization of acyclic deterministic automata in linear time 293 // Dominque Revuz 294 // 295 // Complexity O(|E|) 296 // 297 template <class A> 298 class AcyclicMinimizer { 299 public: 300 typedef typename A::Label Label; 301 typedef typename A::StateId StateId; 302 typedef typename A::StateId ClassId; 303 typedef typename A::Weight Weight; 304 305 AcyclicMinimizer(const ExpandedFst<A>& fst) { 306 Initialize(fst); 307 Refine(fst); 308 } 309 310 const Partition<StateId>& partition() { 311 return partition_; 312 } 313 314 // helper classes 315 private: 316 // DFS visitor to compute the height (distance) to final state. 317 class HeightVisitor { 318 public: 319 HeightVisitor() : max_height_(0), num_states_(0) { } 320 321 // invoked before dfs visit 322 void InitVisit(const Fst<A>& fst) {} 323 324 // invoked when state is discovered (2nd arg is DFS tree root) 325 bool InitState(StateId s, StateId root) { 326 // extend height array and initialize height (distance) to 0 327 for (size_t i = height_.size(); i <= (size_t)s; ++i) 328 height_.push_back(-1); 329 330 if (s >= (StateId)num_states_) num_states_ = s + 1; 331 return true; 332 } 333 334 // invoked when tree arc examined (to undiscoverted state) 335 bool TreeArc(StateId s, const A& arc) { 336 return true; 337 } 338 339 // invoked when back arc examined (to unfinished state) 340 bool BackArc(StateId s, const A& arc) { 341 return true; 342 } 343 344 // invoked when forward or cross arc examined (to finished state) 345 bool ForwardOrCrossArc(StateId s, const A& arc) { 346 if (height_[arc.nextstate] + 1 > height_[s]) 347 height_[s] = height_[arc.nextstate] + 1; 348 return true; 349 } 350 351 // invoked when state finished (parent is kNoStateId for tree root) 352 void FinishState(StateId s, StateId parent, const A* parent_arc) { 353 if (height_[s] == -1) height_[s] = 0; 354 StateId h = height_[s] + 1; 355 if (parent >= 0) { 356 if (h > height_[parent]) height_[parent] = h; 357 if (h > (StateId)max_height_) max_height_ = h; 358 } 359 } 360 361 // invoked after DFS visit 362 void FinishVisit() {} 363 364 size_t max_height() const { return max_height_; } 365 366 const vector<StateId>& height() const { return height_; } 367 368 const size_t num_states() const { return num_states_; } 369 370 private: 371 vector<StateId> height_; 372 size_t max_height_; 373 size_t num_states_; 374 }; 375 376 // helper methods 377 private: 378 // cluster states according to height (distance to final state) 379 void Initialize(const Fst<A>& fst) { 380 // compute height (distance to final state) 381 HeightVisitor hvisitor; 382 DfsVisit(fst, &hvisitor); 383 384 // create initial partition based on height 385 partition_.Initialize(hvisitor.num_states()); 386 partition_.AllocateClasses(hvisitor.max_height() + 1); 387 const vector<StateId>& hstates = hvisitor.height(); 388 for (size_t s = 0; s < hstates.size(); ++s) 389 partition_.Add(s, hstates[s]); 390 } 391 392 // refine states based on arc sort (out degree, arc equivalence) 393 void Refine(const Fst<A>& fst) { 394 typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap; 395 StateComparator<A> comp(fst, partition_); 396 397 // start with tail (height = 0) 398 size_t height = partition_.num_classes(); 399 for (size_t h = 0; h < height; ++h) { 400 EquivalenceMap equiv_classes(comp); 401 402 // sort states within equivalence class 403 PartitionIterator<StateId> siter(partition_, h); 404 equiv_classes[siter.Value()] = h; 405 for (siter.Next(); !siter.Done(); siter.Next()) { 406 const StateId s = siter.Value(); 407 typename EquivalenceMap::const_iterator it = equiv_classes.find(s); 408 if (it == equiv_classes.end()) 409 equiv_classes[s] = partition_.AddClass(); 410 else 411 equiv_classes[s] = it->second; 412 } 413 414 // create refined partition 415 for (siter.Reset(); !siter.Done();) { 416 const StateId s = siter.Value(); 417 const StateId old_class = partition_.class_id(s); 418 const StateId new_class = equiv_classes[s]; 419 420 // a move operation can invalidate the iterator, so 421 // we first update the iterator to the next element 422 // before we move the current element out of the list 423 siter.Next(); 424 if (old_class != new_class) 425 partition_.Move(s, new_class); 426 } 427 } 428 } 429 430 private: 431 Partition<StateId> partition_; 432 }; 433 434 435 // Given a partition and a mutable fst, merge states of Fst inplace 436 // (i.e. destructively). Merging works by taking the first state in 437 // a class of the partition to be the representative state for the class. 438 // Each arc is then reconnected to this state. All states in the class 439 // are merged by adding there arcs to the representative state. 440 template <class A> 441 void MergeStates( 442 const Partition<typename A::StateId>& partition, MutableFst<A>* fst) { 443 typedef typename A::StateId StateId; 444 445 vector<StateId> state_map(partition.num_classes()); 446 for (size_t i = 0; i < (size_t)partition.num_classes(); ++i) { 447 PartitionIterator<StateId> siter(partition, i); 448 state_map[i] = siter.Value(); // first state in partition; 449 } 450 451 // relabel destination states 452 for (size_t c = 0; c < (size_t)partition.num_classes(); ++c) { 453 for (PartitionIterator<StateId> siter(partition, c); 454 !siter.Done(); siter.Next()) { 455 StateId s = siter.Value(); 456 for (MutableArcIterator<MutableFst<A> > aiter(fst, s); 457 !aiter.Done(); aiter.Next()) { 458 A arc = aiter.Value(); 459 arc.nextstate = state_map[partition.class_id(arc.nextstate)]; 460 461 if (s == state_map[c]) // first state just set destination 462 aiter.SetValue(arc); 463 else 464 fst->AddArc(state_map[c], arc); 465 } 466 } 467 } 468 fst->SetStart(state_map[partition.class_id(fst->Start())]); 469 470 Connect(fst); 471 } 472 473 template <class A> 474 void AcceptorMinimize(MutableFst<A>* fst) { 475 typedef typename A::StateId StateId; 476 if (!(fst->Properties(kAcceptor | kUnweighted, true))) 477 LOG(FATAL) << "Input Fst is not an unweighted acceptor"; 478 479 // connect fst before minimization, handles disconnected states 480 Connect(fst); 481 if (fst->NumStates() == 0) return; 482 483 if (fst->Properties(kAcyclic, true)) { 484 // Acyclic minimization (revuz) 485 VLOG(2) << "Acyclic Minimization"; 486 AcyclicMinimizer<A> minimizer(*fst); 487 MergeStates(minimizer.partition(), fst); 488 489 } else { 490 // Cyclic minimizaton (hopcroft) 491 VLOG(2) << "Cyclic Minimization"; 492 CyclicMinimizer<A, LifoQueue<StateId> > minimizer(*fst); 493 MergeStates(minimizer.partition(), fst); 494 } 495 496 // sort arcs before summing 497 ArcSort(fst, ILabelCompare<A>()); 498 499 // sum in appropriate semiring 500 ArcSum(fst); 501 } 502 503 504 // In place minimization of unweighted, deterministic acceptors 505 // 506 // For acyclic automata we use an algorithm from Dominique Revuz that is 507 // linear in the number of arcs (edges) in the machine. 508 // Complexity = O(E) 509 // 510 // For cyclic automata we use the classical hopcroft minimization. 511 // Complexity = O(|E|log(|N|) 512 // 513 template <class A> 514 void Minimize(MutableFst<A>* fst, MutableFst<A>* sfst = 0) { 515 uint64 props = fst->Properties(kAcceptor | kIDeterministic| 516 kWeighted | kUnweighted, true); 517 if (!(props & kIDeterministic)) 518 LOG(FATAL) << "Input Fst is not deterministic"; 519 520 if (!(props & kAcceptor)) { // weighted transducer 521 VectorFst< GallicArc<A, STRING_LEFT> > gfst; 522 Map(*fst, &gfst, ToGallicMapper<A, STRING_LEFT>()); 523 fst->DeleteStates(); 524 gfst.SetProperties(kAcceptor, kAcceptor); 525 Push(&gfst, REWEIGHT_TO_INITIAL); 526 Map(&gfst, QuantizeMapper< GallicArc<A, STRING_LEFT> >()); 527 EncodeMapper< GallicArc<A, STRING_LEFT> > 528 encoder(kEncodeLabels | kEncodeWeights, ENCODE); 529 Encode(&gfst, &encoder); 530 AcceptorMinimize(&gfst); 531 Decode(&gfst, encoder); 532 533 if (sfst == 0) { 534 FactorWeightFst< GallicArc<A, STRING_LEFT>, 535 GallicFactor<typename A::Label, 536 typename A::Weight, STRING_LEFT> > fwfst(gfst); 537 Map(fwfst, fst, FromGallicMapper<A, STRING_LEFT>()); 538 } else { 539 sfst->SetOutputSymbols(fst->OutputSymbols()); 540 GallicToNewSymbolsMapper<A, STRING_LEFT> mapper(sfst); 541 Map(gfst, fst, mapper); 542 fst->SetOutputSymbols(sfst->InputSymbols()); 543 } 544 } else if (props & kWeighted) { // weighted acceptor 545 Push(fst, REWEIGHT_TO_INITIAL); 546 Map(fst, QuantizeMapper<A>()); 547 EncodeMapper<A> encoder(kEncodeLabels | kEncodeWeights, ENCODE); 548 Encode(fst, &encoder); 549 AcceptorMinimize(fst); 550 Decode(fst, encoder); 551 } else { // unweighted acceptor 552 AcceptorMinimize(fst); 553 } 554 } 555 556 } // namespace fst 557 558 #endif // FST_LIB_MINIMIZE_H__ 559