1 // minimize.h 2 // minimize.h 3 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 // 16 // Copyright 2005-2010 Google, Inc. 17 // Author: johans (at) google.com (Johan Schalkwyk) 18 // 19 // \file Functions and classes to minimize a finite state acceptor 20 // 21 22 #ifndef FST_LIB_MINIMIZE_H__ 23 #define FST_LIB_MINIMIZE_H__ 24 25 #include <cmath> 26 27 #include <algorithm> 28 #include <map> 29 #include <queue> 30 #include <vector> 31 using std::vector; 32 33 #include <fst/arcsort.h> 34 #include <fst/connect.h> 35 #include <fst/dfs-visit.h> 36 #include <fst/encode.h> 37 #include <fst/factor-weight.h> 38 #include <fst/fst.h> 39 #include <fst/mutable-fst.h> 40 #include <fst/partition.h> 41 #include <fst/push.h> 42 #include <fst/queue.h> 43 #include <fst/reverse.h> 44 #include <fst/state-map.h> 45 46 47 namespace fst { 48 49 // comparator for creating partition based on sorting on 50 // - states 51 // - final weight 52 // - out degree, 53 // - (input label, output label, weight, destination_block) 54 template <class A> 55 class StateComparator { 56 public: 57 typedef typename A::StateId StateId; 58 typedef typename A::Weight Weight; 59 60 static const uint32 kCompareFinal = 0x00000001; 61 static const uint32 kCompareOutDegree = 0x00000002; 62 static const uint32 kCompareArcs = 0x00000004; 63 static const uint32 kCompareAll = 0x00000007; 64 65 StateComparator(const Fst<A>& fst, 66 const Partition<typename A::StateId>& partition, 67 uint32 flags = kCompareAll) 68 : fst_(fst), partition_(partition), flags_(flags) {} 69 70 // compare state x with state y based on sort criteria 71 bool operator()(const StateId x, const StateId y) const { 72 // check for final state equivalence 73 if (flags_ & kCompareFinal) { 74 const size_t xfinal = fst_.Final(x).Hash(); 75 const size_t yfinal = fst_.Final(y).Hash(); 76 if (xfinal < yfinal) return true; 77 else if (xfinal > yfinal) return false; 78 } 79 80 if (flags_ & kCompareOutDegree) { 81 // check for # arcs 82 if (fst_.NumArcs(x) < fst_.NumArcs(y)) return true; 83 if (fst_.NumArcs(x) > fst_.NumArcs(y)) return false; 84 85 if (flags_ & kCompareArcs) { 86 // # arcs are equal, check for arc match 87 for (ArcIterator<Fst<A> > aiter1(fst_, x), aiter2(fst_, y); 88 !aiter1.Done() && !aiter2.Done(); aiter1.Next(), aiter2.Next()) { 89 const A& arc1 = aiter1.Value(); 90 const A& arc2 = aiter2.Value(); 91 if (arc1.ilabel < arc2.ilabel) return true; 92 if (arc1.ilabel > arc2.ilabel) return false; 93 94 if (partition_.class_id(arc1.nextstate) < 95 partition_.class_id(arc2.nextstate)) return true; 96 if (partition_.class_id(arc1.nextstate) > 97 partition_.class_id(arc2.nextstate)) return false; 98 } 99 } 100 } 101 102 return false; 103 } 104 105 private: 106 const Fst<A>& fst_; 107 const Partition<typename A::StateId>& partition_; 108 const uint32 flags_; 109 }; 110 111 template <class A> const uint32 StateComparator<A>::kCompareFinal; 112 template <class A> const uint32 StateComparator<A>::kCompareOutDegree; 113 template <class A> const uint32 StateComparator<A>::kCompareArcs; 114 template <class A> const uint32 StateComparator<A>::kCompareAll; 115 116 117 // Computes equivalence classes for cyclic Fsts. For cyclic minimization 118 // we use the classic HopCroft minimization algorithm, which is of 119 // 120 // O(E)log(N), 121 // 122 // where E is the number of edges in the machine and N is number of states. 123 // 124 // The following paper describes the original algorithm 125 // An N Log N algorithm for minimizing states in a finite automaton 126 // by John HopCroft, January 1971 127 // 128 template <class A, class Queue> 129 class CyclicMinimizer { 130 public: 131 typedef typename A::Label Label; 132 typedef typename A::StateId StateId; 133 typedef typename A::StateId ClassId; 134 typedef typename A::Weight Weight; 135 typedef ReverseArc<A> RevA; 136 137 CyclicMinimizer(const ExpandedFst<A>& fst) { 138 Initialize(fst); 139 Compute(fst); 140 } 141 142 ~CyclicMinimizer() { 143 delete aiter_queue_; 144 } 145 146 const Partition<StateId>& partition() const { 147 return P_; 148 } 149 150 // helper classes 151 private: 152 typedef ArcIterator<Fst<RevA> > ArcIter; 153 class ArcIterCompare { 154 public: 155 ArcIterCompare(const Partition<StateId>& partition) 156 : partition_(partition) {} 157 158 ArcIterCompare(const ArcIterCompare& comp) 159 : partition_(comp.partition_) {} 160 161 // compare two iterators based on there input labels, and proto state 162 // (partition class Ids) 163 bool operator()(const ArcIter* x, const ArcIter* y) const { 164 const RevA& xarc = x->Value(); 165 const RevA& yarc = y->Value(); 166 return (xarc.ilabel > yarc.ilabel); 167 } 168 169 private: 170 const Partition<StateId>& partition_; 171 }; 172 173 typedef priority_queue<ArcIter*, vector<ArcIter*>, ArcIterCompare> 174 ArcIterQueue; 175 176 // helper methods 177 private: 178 // prepartitions the space into equivalence classes with 179 // same final weight 180 // same # arcs per state 181 // same outgoing arcs 182 void PrePartition(const Fst<A>& fst) { 183 VLOG(5) << "PrePartition"; 184 185 typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap; 186 StateComparator<A> comp(fst, P_, StateComparator<A>::kCompareFinal); 187 EquivalenceMap equiv_map(comp); 188 189 StateIterator<Fst<A> > siter(fst); 190 StateId class_id = P_.AddClass(); 191 P_.Add(siter.Value(), class_id); 192 equiv_map[siter.Value()] = class_id; 193 L_.Enqueue(class_id); 194 for (siter.Next(); !siter.Done(); siter.Next()) { 195 StateId s = siter.Value(); 196 typename EquivalenceMap::const_iterator it = equiv_map.find(s); 197 if (it == equiv_map.end()) { 198 class_id = P_.AddClass(); 199 P_.Add(s, class_id); 200 equiv_map[s] = class_id; 201 L_.Enqueue(class_id); 202 } else { 203 P_.Add(s, it->second); 204 equiv_map[s] = it->second; 205 } 206 } 207 208 VLOG(5) << "Initial Partition: " << P_.num_classes(); 209 } 210 211 // - Create inverse transition Tr_ = rev(fst) 212 // - loop over states in fst and split on final, creating two blocks 213 // in the partition corresponding to final, non-final 214 void Initialize(const Fst<A>& fst) { 215 // construct Tr 216 Reverse(fst, &Tr_); 217 ILabelCompare<RevA> ilabel_comp; 218 ArcSort(&Tr_, ilabel_comp); 219 220 // initial split (F, S - F) 221 P_.Initialize(Tr_.NumStates() - 1); 222 223 // prep partition 224 PrePartition(fst); 225 226 // allocate arc iterator queue 227 ArcIterCompare comp(P_); 228 aiter_queue_ = new ArcIterQueue(comp); 229 } 230 231 // partition all classes with destination C 232 void Split(ClassId C) { 233 // Prep priority queue. Open arc iterator for each state in C, and 234 // insert into priority queue. 235 for (PartitionIterator<StateId> siter(P_, C); 236 !siter.Done(); siter.Next()) { 237 StateId s = siter.Value(); 238 if (Tr_.NumArcs(s + 1)) 239 aiter_queue_->push(new ArcIterator<Fst<RevA> >(Tr_, s + 1)); 240 } 241 242 // Now pop arc iterator from queue, split entering equivalence class 243 // re-insert updated iterator into queue. 244 Label prev_label = -1; 245 while (!aiter_queue_->empty()) { 246 ArcIterator<Fst<RevA> >* aiter = aiter_queue_->top(); 247 aiter_queue_->pop(); 248 if (aiter->Done()) { 249 delete aiter; 250 continue; 251 } 252 253 const RevA& arc = aiter->Value(); 254 StateId from_state = aiter->Value().nextstate - 1; 255 Label from_label = arc.ilabel; 256 if (prev_label != from_label) 257 P_.FinalizeSplit(&L_); 258 259 StateId from_class = P_.class_id(from_state); 260 if (P_.class_size(from_class) > 1) 261 P_.SplitOn(from_state); 262 263 prev_label = from_label; 264 aiter->Next(); 265 if (aiter->Done()) 266 delete aiter; 267 else 268 aiter_queue_->push(aiter); 269 } 270 P_.FinalizeSplit(&L_); 271 } 272 273 // Main loop for hopcroft minimization. 274 void Compute(const Fst<A>& fst) { 275 // process active classes (FIFO, or FILO) 276 while (!L_.Empty()) { 277 ClassId C = L_.Head(); 278 L_.Dequeue(); 279 280 // split on C, all labels in C 281 Split(C); 282 } 283 } 284 285 // helper data 286 private: 287 // Partioning of states into equivalence classes 288 Partition<StateId> P_; 289 290 // L = set of active classes to be processed in partition P 291 Queue L_; 292 293 // reverse transition function 294 VectorFst<RevA> Tr_; 295 296 // Priority queue of open arc iterators for all states in the 'splitter' 297 // equivalence class 298 ArcIterQueue* aiter_queue_; 299 }; 300 301 302 // Computes equivalence classes for acyclic Fsts. The implementation details 303 // for this algorithms is documented by the following paper. 304 // 305 // Minimization of acyclic deterministic automata in linear time 306 // Dominque Revuz 307 // 308 // Complexity O(|E|) 309 // 310 template <class A> 311 class AcyclicMinimizer { 312 public: 313 typedef typename A::Label Label; 314 typedef typename A::StateId StateId; 315 typedef typename A::StateId ClassId; 316 typedef typename A::Weight Weight; 317 318 AcyclicMinimizer(const ExpandedFst<A>& fst) { 319 Initialize(fst); 320 Refine(fst); 321 } 322 323 const Partition<StateId>& partition() { 324 return partition_; 325 } 326 327 // helper classes 328 private: 329 // DFS visitor to compute the height (distance) to final state. 330 class HeightVisitor { 331 public: 332 HeightVisitor() : max_height_(0), num_states_(0) { } 333 334 // invoked before dfs visit 335 void InitVisit(const Fst<A>& fst) {} 336 337 // invoked when state is discovered (2nd arg is DFS tree root) 338 bool InitState(StateId s, StateId root) { 339 // extend height array and initialize height (distance) to 0 340 for (size_t i = height_.size(); i <= s; ++i) 341 height_.push_back(-1); 342 343 if (s >= num_states_) num_states_ = s + 1; 344 return true; 345 } 346 347 // invoked when tree arc examined (to undiscoverted state) 348 bool TreeArc(StateId s, const A& arc) { 349 return true; 350 } 351 352 // invoked when back arc examined (to unfinished state) 353 bool BackArc(StateId s, const A& arc) { 354 return true; 355 } 356 357 // invoked when forward or cross arc examined (to finished state) 358 bool ForwardOrCrossArc(StateId s, const A& arc) { 359 if (height_[arc.nextstate] + 1 > height_[s]) 360 height_[s] = height_[arc.nextstate] + 1; 361 return true; 362 } 363 364 // invoked when state finished (parent is kNoStateId for tree root) 365 void FinishState(StateId s, StateId parent, const A* parent_arc) { 366 if (height_[s] == -1) height_[s] = 0; 367 StateId h = height_[s] + 1; 368 if (parent >= 0) { 369 if (h > height_[parent]) height_[parent] = h; 370 if (h > max_height_) max_height_ = h; 371 } 372 } 373 374 // invoked after DFS visit 375 void FinishVisit() {} 376 377 size_t max_height() const { return max_height_; } 378 379 const vector<StateId>& height() const { return height_; } 380 381 const size_t num_states() const { return num_states_; } 382 383 private: 384 vector<StateId> height_; 385 size_t max_height_; 386 size_t num_states_; 387 }; 388 389 // helper methods 390 private: 391 // cluster states according to height (distance to final state) 392 void Initialize(const Fst<A>& fst) { 393 // compute height (distance to final state) 394 HeightVisitor hvisitor; 395 DfsVisit(fst, &hvisitor); 396 397 // create initial partition based on height 398 partition_.Initialize(hvisitor.num_states()); 399 partition_.AllocateClasses(hvisitor.max_height() + 1); 400 const vector<StateId>& hstates = hvisitor.height(); 401 for (size_t s = 0; s < hstates.size(); ++s) 402 partition_.Add(s, hstates[s]); 403 } 404 405 // refine states based on arc sort (out degree, arc equivalence) 406 void Refine(const Fst<A>& fst) { 407 typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap; 408 StateComparator<A> comp(fst, partition_); 409 410 // start with tail (height = 0) 411 size_t height = partition_.num_classes(); 412 for (size_t h = 0; h < height; ++h) { 413 EquivalenceMap equiv_classes(comp); 414 415 // sort states within equivalence class 416 PartitionIterator<StateId> siter(partition_, h); 417 equiv_classes[siter.Value()] = h; 418 for (siter.Next(); !siter.Done(); siter.Next()) { 419 const StateId s = siter.Value(); 420 typename EquivalenceMap::const_iterator it = equiv_classes.find(s); 421 if (it == equiv_classes.end()) 422 equiv_classes[s] = partition_.AddClass(); 423 else 424 equiv_classes[s] = it->second; 425 } 426 427 // create refined partition 428 for (siter.Reset(); !siter.Done();) { 429 const StateId s = siter.Value(); 430 const StateId old_class = partition_.class_id(s); 431 const StateId new_class = equiv_classes[s]; 432 433 // a move operation can invalidate the iterator, so 434 // we first update the iterator to the next element 435 // before we move the current element out of the list 436 siter.Next(); 437 if (old_class != new_class) 438 partition_.Move(s, new_class); 439 } 440 } 441 } 442 443 private: 444 Partition<StateId> partition_; 445 }; 446 447 448 // Given a partition and a mutable fst, merge states of Fst inplace 449 // (i.e. destructively). Merging works by taking the first state in 450 // a class of the partition to be the representative state for the class. 451 // Each arc is then reconnected to this state. All states in the class 452 // are merged by adding there arcs to the representative state. 453 template <class A> 454 void MergeStates( 455 const Partition<typename A::StateId>& partition, MutableFst<A>* fst) { 456 typedef typename A::StateId StateId; 457 458 vector<StateId> state_map(partition.num_classes()); 459 for (size_t i = 0; i < partition.num_classes(); ++i) { 460 PartitionIterator<StateId> siter(partition, i); 461 state_map[i] = siter.Value(); // first state in partition; 462 } 463 464 // relabel destination states 465 for (size_t c = 0; c < partition.num_classes(); ++c) { 466 for (PartitionIterator<StateId> siter(partition, c); 467 !siter.Done(); siter.Next()) { 468 StateId s = siter.Value(); 469 for (MutableArcIterator<MutableFst<A> > aiter(fst, s); 470 !aiter.Done(); aiter.Next()) { 471 A arc = aiter.Value(); 472 arc.nextstate = state_map[partition.class_id(arc.nextstate)]; 473 474 if (s == state_map[c]) // first state just set destination 475 aiter.SetValue(arc); 476 else 477 fst->AddArc(state_map[c], arc); 478 } 479 } 480 } 481 fst->SetStart(state_map[partition.class_id(fst->Start())]); 482 483 Connect(fst); 484 } 485 486 template <class A> 487 void AcceptorMinimize(MutableFst<A>* fst) { 488 typedef typename A::StateId StateId; 489 if (!(fst->Properties(kAcceptor | kUnweighted, true))) { 490 FSTERROR() << "FST is not an unweighted acceptor"; 491 fst->SetProperties(kError, kError); 492 return; 493 } 494 495 // connect fst before minimization, handles disconnected states 496 Connect(fst); 497 if (fst->NumStates() == 0) return; 498 499 if (fst->Properties(kAcyclic, true)) { 500 // Acyclic minimization (revuz) 501 VLOG(2) << "Acyclic Minimization"; 502 ArcSort(fst, ILabelCompare<A>()); 503 AcyclicMinimizer<A> minimizer(*fst); 504 MergeStates(minimizer.partition(), fst); 505 506 } else { 507 // Cyclic minimizaton (hopcroft) 508 VLOG(2) << "Cyclic Minimization"; 509 CyclicMinimizer<A, LifoQueue<StateId> > minimizer(*fst); 510 MergeStates(minimizer.partition(), fst); 511 } 512 513 // Merge in appropriate semiring 514 ArcUniqueMapper<A> mapper(*fst); 515 StateMap(fst, mapper); 516 } 517 518 519 // In place minimization of deterministic weighted automata and transducers. 520 // For transducers, then the 'sfst' argument is not null, the algorithm 521 // produces a compact factorization of the minimal transducer. 522 // 523 // In the acyclic case, we use an algorithm from Dominique Revuz that 524 // is linear in the number of arcs (edges) in the machine. 525 // Complexity = O(E) 526 // 527 // In the cyclic case, we use the classical hopcroft minimization. 528 // Complexity = O(|E|log(|N|) 529 // 530 template <class A> 531 void Minimize(MutableFst<A>* fst, 532 MutableFst<A>* sfst = 0, 533 float delta = kDelta) { 534 uint64 props = fst->Properties(kAcceptor | kIDeterministic| 535 kWeighted | kUnweighted, true); 536 if (!(props & kIDeterministic)) { 537 FSTERROR() << "FST is not deterministic"; 538 fst->SetProperties(kError, kError); 539 return; 540 } 541 542 if (!(props & kAcceptor)) { // weighted transducer 543 VectorFst< GallicArc<A, STRING_LEFT> > gfst; 544 ArcMap(*fst, &gfst, ToGallicMapper<A, STRING_LEFT>()); 545 fst->DeleteStates(); 546 gfst.SetProperties(kAcceptor, kAcceptor); 547 Push(&gfst, REWEIGHT_TO_INITIAL, delta); 548 ArcMap(&gfst, QuantizeMapper< GallicArc<A, STRING_LEFT> >(delta)); 549 EncodeMapper< GallicArc<A, STRING_LEFT> > 550 encoder(kEncodeLabels | kEncodeWeights, ENCODE); 551 Encode(&gfst, &encoder); 552 AcceptorMinimize(&gfst); 553 Decode(&gfst, encoder); 554 555 if (sfst == 0) { 556 FactorWeightFst< GallicArc<A, STRING_LEFT>, 557 GallicFactor<typename A::Label, 558 typename A::Weight, STRING_LEFT> > fwfst(gfst); 559 SymbolTable *osyms = fst->OutputSymbols() ? 560 fst->OutputSymbols()->Copy() : 0; 561 ArcMap(fwfst, fst, FromGallicMapper<A, STRING_LEFT>()); 562 fst->SetOutputSymbols(osyms); 563 delete osyms; 564 } else { 565 sfst->SetOutputSymbols(fst->OutputSymbols()); 566 GallicToNewSymbolsMapper<A, STRING_LEFT> mapper(sfst); 567 ArcMap(gfst, fst, &mapper); 568 fst->SetOutputSymbols(sfst->InputSymbols()); 569 } 570 } else if (props & kWeighted) { // weighted acceptor 571 Push(fst, REWEIGHT_TO_INITIAL, delta); 572 ArcMap(fst, QuantizeMapper<A>(delta)); 573 EncodeMapper<A> encoder(kEncodeLabels | kEncodeWeights, ENCODE); 574 Encode(fst, &encoder); 575 AcceptorMinimize(fst); 576 Decode(fst, encoder); 577 } else { // unweighted acceptor 578 AcceptorMinimize(fst); 579 } 580 } 581 582 } // namespace fst 583 584 #endif // FST_LIB_MINIMIZE_H__ 585