1 // queue.h 2 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Copyright 2005-2010 Google, Inc. 16 // Author: allauzen (at) google.com (Cyril Allauzen) 17 // 18 // \file 19 // Functions and classes for various Fst state queues with 20 // a unified interface. 21 22 #ifndef FST_LIB_QUEUE_H__ 23 #define FST_LIB_QUEUE_H__ 24 25 #include <deque> 26 using std::deque; 27 #include <vector> 28 using std::vector; 29 30 #include <fst/arcfilter.h> 31 #include <fst/connect.h> 32 #include <fst/heap.h> 33 #include <fst/topsort.h> 34 35 36 namespace fst { 37 38 // template <class S> 39 // class Queue { 40 // public: 41 // typedef typename S StateId; 42 // 43 // // Ctr: may need args (e.g., Fst, comparator) for some queues 44 // Queue(...); 45 // // Returns the head of the queue 46 // StateId Head() const; 47 // // Inserts a state 48 // void Enqueue(StateId s); 49 // // Removes the head of the queue 50 // void Dequeue(); 51 // // Updates ordering of state s when weight changes, if necessary 52 // void Update(StateId s); 53 // // Does the queue contain no elements? 54 // bool Empty() const; 55 // // Remove all states from queue 56 // void Clear(); 57 // }; 58 59 // State queue types. 
enum QueueType {
  TRIVIAL_QUEUE = 0,         // Single state queue
  FIFO_QUEUE = 1,            // First-in, first-out queue
  LIFO_QUEUE = 2,            // Last-in, first-out queue
  SHORTEST_FIRST_QUEUE = 3,  // Shortest-first queue
  TOP_ORDER_QUEUE = 4,       // Topologically-ordered queue
  STATE_ORDER_QUEUE = 5,     // State-ID ordered queue
  SCC_QUEUE = 6,             // Component graph top-ordered meta-queue
  AUTO_QUEUE = 7,            // Auto-selected queue
  OTHER_QUEUE = 8            // Any other discipline (e.g. wrapper queues)
};


// QueueBase, templated on the StateId, is the base class shared by the
// queues considered by AutoQueue. Its public methods are non-virtual and
// forward to private pure-virtual hooks (Head_(), Enqueue_(), ...) that
// every concrete queue implements.
template <class S>
class QueueBase {
 public:
  typedef S StateId;

  // 'type' identifies the concrete discipline; it is returned by Type().
  QueueBase(QueueType type) : queue_type_(type), error_(false) {}
  virtual ~QueueBase() {}

  // Returns the state at the head of the queue.
  StateId Head() const { return Head_(); }
  // Inserts state 's'.
  void Enqueue(StateId s) { Enqueue_(s); }
  // Removes the head of the queue.
  void Dequeue() { Dequeue_(); }
  // Reorders state 's' after its weight changed, if the discipline cares.
  void Update(StateId s) { Update_(s); }
  // True when the queue holds no state.
  bool Empty() const { return Empty_(); }
  // Removes all states from the queue.
  void Clear() { Clear_(); }
  // Returns the discipline passed to the constructor. Declared const so the
  // type can be queried through a const reference or pointer.
  QueueType Type() const { return queue_type_; }
  // Error flag; set by a derived queue (e.g. on invalid input).
  bool Error() const { return error_; }
  void SetError(bool error) { error_ = error; }

 private:
  // This allows base-class virtual access to non-virtual derived-
  // class members of the same name. It makes the derived class more
  // efficient to use but unsafe to further derive.
  virtual StateId Head_() const = 0;
  virtual void Enqueue_(StateId s) = 0;
  virtual void Dequeue_() = 0;
  virtual void Update_(StateId s) = 0;
  virtual bool Empty_() const = 0;
  virtual void Clear_() = 0;

  QueueType queue_type_;
  bool error_;
};


// Trivial queue discipline, templated on the StateId. You may enqueue
// at most one state at a time. It is used for strongly connected components
// with only one state and no self loops.
111 template <class S> 112 class TrivialQueue : public QueueBase<S> { 113 public: 114 typedef S StateId; 115 116 TrivialQueue() : QueueBase<S>(TRIVIAL_QUEUE), front_(kNoStateId) {} 117 StateId Head() const { return front_; } 118 void Enqueue(StateId s) { front_ = s; } 119 void Dequeue() { front_ = kNoStateId; } 120 void Update(StateId s) {} 121 bool Empty() const { return front_ == kNoStateId; } 122 void Clear() { front_ = kNoStateId; } 123 124 125 private: 126 // This allows base-class virtual access to non-virtual derived- 127 // class members of the same name. It makes the derived class more 128 // efficient to use but unsafe to further derive. 129 virtual StateId Head_() const { return Head(); } 130 virtual void Enqueue_(StateId s) { Enqueue(s); } 131 virtual void Dequeue_() { Dequeue(); } 132 virtual void Update_(StateId s) { Update(s); } 133 virtual bool Empty_() const { return Empty(); } 134 virtual void Clear_() { return Clear(); } 135 136 StateId front_; 137 }; 138 139 140 // First-in, first-out queue discipline, templated on the StateId. 141 template <class S> 142 class FifoQueue : public QueueBase<S>, public deque<S> { 143 public: 144 using deque<S>::back; 145 using deque<S>::push_front; 146 using deque<S>::pop_back; 147 using deque<S>::empty; 148 using deque<S>::clear; 149 150 typedef S StateId; 151 152 FifoQueue() : QueueBase<S>(FIFO_QUEUE) {} 153 StateId Head() const { return back(); } 154 void Enqueue(StateId s) { push_front(s); } 155 void Dequeue() { pop_back(); } 156 void Update(StateId s) {} 157 bool Empty() const { return empty(); } 158 void Clear() { clear(); } 159 160 private: 161 // This allows base-class virtual access to non-virtual derived- 162 // class members of the same name. It makes the derived class more 163 // efficient to use but unsafe to further derive. 
164 virtual StateId Head_() const { return Head(); } 165 virtual void Enqueue_(StateId s) { Enqueue(s); } 166 virtual void Dequeue_() { Dequeue(); } 167 virtual void Update_(StateId s) { Update(s); } 168 virtual bool Empty_() const { return Empty(); } 169 virtual void Clear_() { return Clear(); } 170 }; 171 172 173 // Last-in, first-out queue discipline, templated on the StateId. 174 template <class S> 175 class LifoQueue : public QueueBase<S>, public deque<S> { 176 public: 177 using deque<S>::front; 178 using deque<S>::push_front; 179 using deque<S>::pop_front; 180 using deque<S>::empty; 181 using deque<S>::clear; 182 183 typedef S StateId; 184 185 LifoQueue() : QueueBase<S>(LIFO_QUEUE) {} 186 StateId Head() const { return front(); } 187 void Enqueue(StateId s) { push_front(s); } 188 void Dequeue() { pop_front(); } 189 void Update(StateId s) {} 190 bool Empty() const { return empty(); } 191 void Clear() { clear(); } 192 193 private: 194 // This allows base-class virtual access to non-virtual derived- 195 // class members of the same name. It makes the derived class more 196 // efficient to use but unsafe to further derive. 197 virtual StateId Head_() const { return Head(); } 198 virtual void Enqueue_(StateId s) { Enqueue(s); } 199 virtual void Dequeue_() { Dequeue(); } 200 virtual void Update_(StateId s) { Update(s); } 201 virtual bool Empty_() const { return Empty(); } 202 virtual void Clear_() { return Clear(); } 203 }; 204 205 206 // Shortest-first queue discipline, templated on the StateId and 207 // comparison function object. Comparison function object COMP is 208 // used to compare two StateIds. If a (single) state's order changes, 209 // it can be reordered in the queue with a call to Update(). 210 // If 'update == false', call to Update() does not reorder the queue. 
211 template <typename S, typename C, bool update = true> 212 class ShortestFirstQueue : public QueueBase<S> { 213 public: 214 typedef S StateId; 215 typedef C Compare; 216 217 ShortestFirstQueue(C comp) 218 : QueueBase<S>(SHORTEST_FIRST_QUEUE), heap_(comp) {} 219 220 StateId Head() const { return heap_.Top(); } 221 222 void Enqueue(StateId s) { 223 if (update) { 224 for (StateId i = key_.size(); i <= s; ++i) 225 key_.push_back(kNoKey); 226 key_[s] = heap_.Insert(s); 227 } else { 228 heap_.Insert(s); 229 } 230 } 231 232 void Dequeue() { 233 if (update) 234 key_[heap_.Pop()] = kNoKey; 235 else 236 heap_.Pop(); 237 } 238 239 void Update(StateId s) { 240 if (!update) 241 return; 242 if (s >= key_.size() || key_[s] == kNoKey) { 243 Enqueue(s); 244 } else { 245 heap_.Update(key_[s], s); 246 } 247 } 248 249 bool Empty() const { return heap_.Empty(); } 250 251 void Clear() { 252 heap_.Clear(); 253 if (update) key_.clear(); 254 } 255 256 private: 257 Heap<S, C, false> heap_; 258 vector<ssize_t> key_; 259 260 // This allows base-class virtual access to non-virtual derived- 261 // class members of the same name. It makes the derived class more 262 // efficient to use but unsafe to further derive. 263 virtual StateId Head_() const { return Head(); } 264 virtual void Enqueue_(StateId s) { Enqueue(s); } 265 virtual void Dequeue_() { Dequeue(); } 266 virtual void Update_(StateId s) { Update(s); } 267 virtual bool Empty_() const { return Empty(); } 268 virtual void Clear_() { return Clear(); } 269 }; 270 271 272 // Given a vector that maps from states to weights and a Less 273 // comparison function object between weights, this class defines a 274 // comparison function object between states. 
275 template <typename S, typename L> 276 class StateWeightCompare { 277 public: 278 typedef L Less; 279 typedef typename L::Weight Weight; 280 typedef S StateId; 281 282 StateWeightCompare(const vector<Weight>& weights, const L &less) 283 : weights_(weights), less_(less) {} 284 285 bool operator()(const S x, const S y) const { 286 return less_(weights_[x], weights_[y]); 287 } 288 289 private: 290 const vector<Weight>& weights_; 291 L less_; 292 }; 293 294 295 // Shortest-first queue discipline, templated on the StateId and Weight, is 296 // specialized to use the weight's natural order for the comparison function. 297 template <typename S, typename W> 298 class NaturalShortestFirstQueue : 299 public ShortestFirstQueue<S, StateWeightCompare<S, NaturalLess<W> > > { 300 public: 301 typedef StateWeightCompare<S, NaturalLess<W> > C; 302 303 NaturalShortestFirstQueue(const vector<W> &distance) : 304 ShortestFirstQueue<S, C>(C(distance, less_)) {} 305 306 private: 307 NaturalLess<W> less_; 308 }; 309 310 // Topological-order queue discipline, templated on the StateId. 311 // States are ordered in the queue topologically. The FST must be acyclic. 312 template <class S> 313 class TopOrderQueue : public QueueBase<S> { 314 public: 315 typedef S StateId; 316 317 // This constructor computes the top. order. It accepts an arc filter 318 // to limit the transitions considered in that computation (e.g., only 319 // the epsilon graph). 320 template <class Arc, class ArcFilter> 321 TopOrderQueue(const Fst<Arc> &fst, ArcFilter filter) 322 : QueueBase<S>(TOP_ORDER_QUEUE), front_(0), back_(kNoStateId), 323 order_(0), state_(0) { 324 bool acyclic; 325 TopOrderVisitor<Arc> top_order_visitor(&order_, &acyclic); 326 DfsVisit(fst, &top_order_visitor, filter); 327 if (!acyclic) { 328 FSTERROR() << "TopOrderQueue: fst is not acyclic."; 329 QueueBase<S>::SetError(true); 330 } 331 state_.resize(order_.size(), kNoStateId); 332 } 333 334 // This constructor is passed the top. 
order, useful when we know it 335 // beforehand. 336 TopOrderQueue(const vector<StateId> &order) 337 : QueueBase<S>(TOP_ORDER_QUEUE), front_(0), back_(kNoStateId), 338 order_(order), state_(order.size(), kNoStateId) {} 339 340 StateId Head() const { return state_[front_]; } 341 342 void Enqueue(StateId s) { 343 if (front_ > back_) front_ = back_ = order_[s]; 344 else if (order_[s] > back_) back_ = order_[s]; 345 else if (order_[s] < front_) front_ = order_[s]; 346 state_[order_[s]] = s; 347 } 348 349 void Dequeue() { 350 state_[front_] = kNoStateId; 351 while ((front_ <= back_) && (state_[front_] == kNoStateId)) ++front_; 352 } 353 354 void Update(StateId s) {} 355 356 bool Empty() const { return front_ > back_; } 357 358 void Clear() { 359 for (StateId i = front_; i <= back_; ++i) state_[i] = kNoStateId; 360 back_ = kNoStateId; 361 front_ = 0; 362 } 363 364 private: 365 StateId front_; 366 StateId back_; 367 vector<StateId> order_; 368 vector<StateId> state_; 369 370 // This allows base-class virtual access to non-virtual derived- 371 // class members of the same name. It makes the derived class more 372 // efficient to use but unsafe to further derive. 373 virtual StateId Head_() const { return Head(); } 374 virtual void Enqueue_(StateId s) { Enqueue(s); } 375 virtual void Dequeue_() { Dequeue(); } 376 virtual void Update_(StateId s) { Update(s); } 377 virtual bool Empty_() const { return Empty(); } 378 virtual void Clear_() { return Clear(); } 379 }; 380 381 382 // State order queue discipline, templated on the StateId. 383 // States are ordered in the queue by state Id. 
384 template <class S> 385 class StateOrderQueue : public QueueBase<S> { 386 public: 387 typedef S StateId; 388 389 StateOrderQueue() 390 : QueueBase<S>(STATE_ORDER_QUEUE), front_(0), back_(kNoStateId) {} 391 392 StateId Head() const { return front_; } 393 394 void Enqueue(StateId s) { 395 if (front_ > back_) front_ = back_ = s; 396 else if (s > back_) back_ = s; 397 else if (s < front_) front_ = s; 398 while (enqueued_.size() <= s) enqueued_.push_back(false); 399 enqueued_[s] = true; 400 } 401 402 void Dequeue() { 403 enqueued_[front_] = false; 404 while ((front_ <= back_) && (enqueued_[front_] == false)) ++front_; 405 } 406 407 void Update(StateId s) {} 408 409 bool Empty() const { return front_ > back_; } 410 411 void Clear() { 412 for (StateId i = front_; i <= back_; ++i) enqueued_[i] = false; 413 front_ = 0; 414 back_ = kNoStateId; 415 } 416 417 private: 418 StateId front_; 419 StateId back_; 420 vector<bool> enqueued_; 421 422 // This allows base-class virtual access to non-virtual derived- 423 // class members of the same name. It makes the derived class more 424 // efficient to use but unsafe to further derive. 425 virtual StateId Head_() const { return Head(); } 426 virtual void Enqueue_(StateId s) { Enqueue(s); } 427 virtual void Dequeue_() { Dequeue(); } 428 virtual void Update_(StateId s) { Update(s); } 429 virtual bool Empty_() const { return Empty(); } 430 virtual void Clear_() { return Clear(); } 431 432 }; 433 434 435 // SCC topological-order meta-queue discipline, templated on the StateId S 436 // and a queue Q, which is used inside each SCC. It visits the SCC's 437 // of an FST in topological order. Its constructor is passed the queues to 438 // to use within an SCC. 439 template <class S, class Q> 440 class SccQueue : public QueueBase<S> { 441 public: 442 typedef S StateId; 443 typedef Q Queue; 444 445 // Constructor takes a vector specifying the SCC number per state 446 // and a vector giving the queue to use per SCC number. 
447 SccQueue(const vector<StateId> &scc, vector<Queue*> *queue) 448 : QueueBase<S>(SCC_QUEUE), queue_(queue), scc_(scc), front_(0), 449 back_(kNoStateId) {} 450 451 StateId Head() const { 452 while ((front_ <= back_) && 453 (((*queue_)[front_] && (*queue_)[front_]->Empty()) 454 || (((*queue_)[front_] == 0) && 455 ((front_ >= trivial_queue_.size()) 456 || (trivial_queue_[front_] == kNoStateId))))) 457 ++front_; 458 if ((*queue_)[front_]) 459 return (*queue_)[front_]->Head(); 460 else 461 return trivial_queue_[front_]; 462 } 463 464 void Enqueue(StateId s) { 465 if (front_ > back_) front_ = back_ = scc_[s]; 466 else if (scc_[s] > back_) back_ = scc_[s]; 467 else if (scc_[s] < front_) front_ = scc_[s]; 468 if ((*queue_)[scc_[s]]) { 469 (*queue_)[scc_[s]]->Enqueue(s); 470 } else { 471 while (trivial_queue_.size() <= scc_[s]) 472 trivial_queue_.push_back(kNoStateId); 473 trivial_queue_[scc_[s]] = s; 474 } 475 } 476 477 void Dequeue() { 478 if ((*queue_)[front_]) 479 (*queue_)[front_]->Dequeue(); 480 else if (front_ < trivial_queue_.size()) 481 trivial_queue_[front_] = kNoStateId; 482 } 483 484 void Update(StateId s) { 485 if ((*queue_)[scc_[s]]) 486 (*queue_)[scc_[s]]->Update(s); 487 } 488 489 bool Empty() const { 490 if (front_ < back_) // Queue scc # back_ not empty unless back_==front_ 491 return false; 492 else if (front_ > back_) 493 return true; 494 else if ((*queue_)[front_]) 495 return (*queue_)[front_]->Empty(); 496 else 497 return (front_ >= trivial_queue_.size()) 498 || (trivial_queue_[front_] == kNoStateId); 499 } 500 501 void Clear() { 502 for (StateId i = front_; i <= back_; ++i) 503 if ((*queue_)[i]) 504 (*queue_)[i]->Clear(); 505 else if (i < trivial_queue_.size()) 506 trivial_queue_[i] = kNoStateId; 507 front_ = 0; 508 back_ = kNoStateId; 509 } 510 511 private: 512 vector<Queue*> *queue_; 513 const vector<StateId> &scc_; 514 mutable StateId front_; 515 StateId back_; 516 vector<StateId> trivial_queue_; 517 518 // This allows base-class virtual access to 
non-virtual derived- 519 // class members of the same name. It makes the derived class more 520 // efficient to use but unsafe to further derive. 521 virtual StateId Head_() const { return Head(); } 522 virtual void Enqueue_(StateId s) { Enqueue(s); } 523 virtual void Dequeue_() { Dequeue(); } 524 virtual void Update_(StateId s) { Update(s); } 525 virtual bool Empty_() const { return Empty(); } 526 virtual void Clear_() { return Clear(); } 527 528 DISALLOW_COPY_AND_ASSIGN(SccQueue); 529 }; 530 531 532 // Automatic queue discipline, templated on the StateId. It selects a 533 // queue discipline for a given FST based on its properties. 534 template <class S> 535 class AutoQueue : public QueueBase<S> { 536 public: 537 typedef S StateId; 538 539 // This constructor takes a state distance vector that, if non-null and if 540 // the Weight type has the path property, will entertain the 541 // shortest-first queue using the natural order w.r.t to the distance. 542 template <class Arc, class ArcFilter> 543 AutoQueue(const Fst<Arc> &fst, const vector<typename Arc::Weight> *distance, 544 ArcFilter filter) : QueueBase<S>(AUTO_QUEUE) { 545 typedef typename Arc::Weight Weight; 546 typedef StateWeightCompare< StateId, NaturalLess<Weight> > Compare; 547 548 // First check if the FST is known to have these properties. 549 uint64 props = fst.Properties(kAcyclic | kCyclic | 550 kTopSorted | kUnweighted, false); 551 if ((props & kTopSorted) || fst.Start() == kNoStateId) { 552 queue_ = new StateOrderQueue<StateId>(); 553 VLOG(2) << "AutoQueue: using state-order discipline"; 554 } else if (props & kAcyclic) { 555 queue_ = new TopOrderQueue<StateId>(fst, filter); 556 VLOG(2) << "AutoQueue: using top-order discipline"; 557 } else if ((props & kUnweighted) && (Weight::Properties() & kIdempotent)) { 558 queue_ = new LifoQueue<StateId>(); 559 VLOG(2) << "AutoQueue: using LIFO discipline"; 560 } else { 561 uint64 properties; 562 // Decompose into strongly-connected components. 
563 SccVisitor<Arc> scc_visitor(&scc_, 0, 0, &properties); 564 DfsVisit(fst, &scc_visitor, filter); 565 StateId nscc = *max_element(scc_.begin(), scc_.end()) + 1; 566 vector<QueueType> queue_types(nscc); 567 NaturalLess<Weight> *less = 0; 568 Compare *comp = 0; 569 if (distance && (Weight::Properties() & kPath)) { 570 less = new NaturalLess<Weight>; 571 comp = new Compare(*distance, *less); 572 } 573 // Find the queue type to use per SCC. 574 bool unweighted; 575 bool all_trivial; 576 SccQueueType(fst, scc_, &queue_types, filter, less, &all_trivial, 577 &unweighted); 578 // If unweighted and semiring is idempotent, use lifo queue. 579 if (unweighted) { 580 queue_ = new LifoQueue<StateId>(); 581 VLOG(2) << "AutoQueue: using LIFO discipline"; 582 delete comp; 583 delete less; 584 return; 585 } 586 // If all the scc are trivial, FST is acyclic and the scc# gives 587 // the topological order. 588 if (all_trivial) { 589 queue_ = new TopOrderQueue<StateId>(scc_); 590 VLOG(2) << "AutoQueue: using top-order discipline"; 591 delete comp; 592 delete less; 593 return; 594 } 595 VLOG(2) << "AutoQueue: using SCC meta-discipline"; 596 queues_.resize(nscc); 597 for (StateId i = 0; i < nscc; ++i) { 598 switch(queue_types[i]) { 599 case TRIVIAL_QUEUE: 600 queues_[i] = 0; 601 VLOG(3) << "AutoQueue: SCC #" << i 602 << ": using trivial discipline"; 603 break; 604 case SHORTEST_FIRST_QUEUE: 605 queues_[i] = new ShortestFirstQueue<StateId, Compare, false>(*comp); 606 VLOG(3) << "AutoQueue: SCC #" << i << 607 ": using shortest-first discipline"; 608 break; 609 case LIFO_QUEUE: 610 queues_[i] = new LifoQueue<StateId>(); 611 VLOG(3) << "AutoQueue: SCC #" << i 612 << ": using LIFO disciplle"; 613 break; 614 case FIFO_QUEUE: 615 default: 616 queues_[i] = new FifoQueue<StateId>(); 617 VLOG(3) << "AutoQueue: SCC #" << i 618 << ": using FIFO disciplle"; 619 break; 620 } 621 } 622 queue_ = new SccQueue< StateId, QueueBase<StateId> >(scc_, &queues_); 623 delete comp; 624 delete less; 625 } 626 } 
627 628 ~AutoQueue() { 629 for (StateId i = 0; i < queues_.size(); ++i) 630 delete queues_[i]; 631 delete queue_; 632 } 633 634 StateId Head() const { return queue_->Head(); } 635 636 void Enqueue(StateId s) { queue_->Enqueue(s); } 637 638 void Dequeue() { queue_->Dequeue(); } 639 640 void Update(StateId s) { queue_->Update(s); } 641 642 bool Empty() const { return queue_->Empty(); } 643 644 void Clear() { queue_->Clear(); } 645 646 647 private: 648 QueueBase<StateId> *queue_; 649 vector< QueueBase<StateId>* > queues_; 650 vector<StateId> scc_; 651 652 template <class Arc, class ArcFilter, class Less> 653 static void SccQueueType(const Fst<Arc> &fst, 654 const vector<StateId> &scc, 655 vector<QueueType> *queue_types, 656 ArcFilter filter, Less *less, 657 bool *all_trivial, bool *unweighted); 658 659 // This allows base-class virtual access to non-virtual derived- 660 // class members of the same name. It makes the derived class more 661 // efficient to use but unsafe to further derive. 662 virtual StateId Head_() const { return Head(); } 663 664 virtual void Enqueue_(StateId s) { Enqueue(s); } 665 666 virtual void Dequeue_() { Dequeue(); } 667 668 virtual void Update_(StateId s) { Update(s); } 669 670 virtual bool Empty_() const { return Empty(); } 671 672 virtual void Clear_() { return Clear(); } 673 674 DISALLOW_COPY_AND_ASSIGN(AutoQueue); 675 }; 676 677 678 // Examines the states in an Fst's strongly connected components and 679 // determines which type of queue to use per SCC. Stores result in 680 // vector QUEUE_TYPES, which is assumed to have length equal to the 681 // number of SCCs. An arc filter is used to limit the transitions 682 // considered (e.g., only the epsilon graph). ALL_TRIVIAL is set 683 // to true if every queue is the trivial queue. UNWEIGHTED is set to 684 // true if the semiring is idempotent and all the arc weights are equal to 685 // Zero() or One(). 
686 template <class StateId> 687 template <class A, class ArcFilter, class Less> 688 void AutoQueue<StateId>::SccQueueType(const Fst<A> &fst, 689 const vector<StateId> &scc, 690 vector<QueueType> *queue_type, 691 ArcFilter filter, Less *less, 692 bool *all_trivial, bool *unweighted) { 693 typedef A Arc; 694 typedef typename A::StateId StateId; 695 typedef typename A::Weight Weight; 696 697 *all_trivial = true; 698 *unweighted = true; 699 700 for (StateId i = 0; i < queue_type->size(); ++i) 701 (*queue_type)[i] = TRIVIAL_QUEUE; 702 703 for (StateIterator< Fst<Arc> > sit(fst); !sit.Done(); sit.Next()) { 704 StateId state = sit.Value(); 705 for (ArcIterator< Fst<Arc> > ait(fst, state); 706 !ait.Done(); 707 ait.Next()) { 708 const Arc &arc = ait.Value(); 709 if (!filter(arc)) continue; 710 if (scc[state] == scc[arc.nextstate]) { 711 QueueType &type = (*queue_type)[scc[state]]; 712 if (!less || ((*less)(arc.weight, Weight::One()))) 713 type = FIFO_QUEUE; 714 else if ((type == TRIVIAL_QUEUE) || (type == LIFO_QUEUE)) { 715 if (!(Weight::Properties() & kIdempotent) || 716 (arc.weight != Weight::Zero() && arc.weight != Weight::One())) 717 type = SHORTEST_FIRST_QUEUE; 718 else 719 type = LIFO_QUEUE; 720 } 721 if (type != TRIVIAL_QUEUE) *all_trivial = false; 722 } 723 if (!(Weight::Properties() & kIdempotent) || 724 (arc.weight != Weight::Zero() && arc.weight != Weight::One())) 725 *unweighted = false; 726 } 727 } 728 } 729 730 731 // An A* estimate is a function object that maps from a state ID to a 732 // an estimate of the shortest distance to the final states. 733 // The trivial A* estimate is always One(). 
734 template <typename S, typename W> 735 struct TrivialAStarEstimate { 736 W operator()(S s) const { return W::One(); } 737 }; 738 739 740 // Given a vector that maps from states to weights representing the 741 // shortest distance from the initial state, a Less comparison 742 // function object between weights, and an estimate E of the 743 // shortest distance to the final states, this class defines a 744 // comparison function object between states. 745 template <typename S, typename L, typename E> 746 class AStarWeightCompare { 747 public: 748 typedef L Less; 749 typedef typename L::Weight Weight; 750 typedef S StateId; 751 752 AStarWeightCompare(const vector<Weight>& weights, const L &less, 753 const E &estimate) 754 : weights_(weights), less_(less), estimate_(estimate) {} 755 756 bool operator()(const S x, const S y) const { 757 Weight wx = Times(weights_[x], estimate_(x)); 758 Weight wy = Times(weights_[y], estimate_(y)); 759 return less_(wx, wy); 760 } 761 762 private: 763 const vector<Weight>& weights_; 764 L less_; 765 const E &estimate_; 766 }; 767 768 769 // A* queue discipline, templated on the StateId, Weight and an 770 // estimate E of the shortest distance to the final states, is specialized 771 // to use the weight's natural order for the comparison function. 772 template <typename S, typename W, typename E> 773 class NaturalAStarQueue : 774 public ShortestFirstQueue<S, AStarWeightCompare<S, NaturalLess<W>, E> > { 775 public: 776 typedef AStarWeightCompare<S, NaturalLess<W>, E> C; 777 778 NaturalAStarQueue(const vector<W> &distance, const E &estimate) : 779 ShortestFirstQueue<S, C>(C(distance, less_, estimate)) {} 780 781 private: 782 NaturalLess<W> less_; 783 }; 784 785 786 // A state equivalence class is a function object that 787 // maps from a state ID to an equivalence class (state) ID. 788 // The trivial equivalence class maps a state to itself. 
789 template <typename S> 790 struct TrivialStateEquivClass { 791 S operator()(S s) const { return s; } 792 }; 793 794 795 // Distance-based pruning queue discipline: Enqueues a state 's' 796 // only when its shortest distance (so far), as specified by 797 // 'distance', is less than (as specified by 'comp') the shortest 798 // distance Times() the 'threshold' to any state in the same 799 // equivalence class, as specified by the function object 800 // 'class_func'. The underlying queue discipline is specified by 801 // 'queue'. The ownership of 'queue' is given to this class. 802 template <typename Q, typename L, typename C> 803 class PruneQueue : public QueueBase<typename Q::StateId> { 804 public: 805 typedef typename Q::StateId StateId; 806 typedef typename L::Weight Weight; 807 808 PruneQueue(const vector<Weight> &distance, Q *queue, L comp, 809 const C &class_func, Weight threshold) 810 : QueueBase<StateId>(OTHER_QUEUE), 811 distance_(distance), 812 queue_(queue), 813 less_(comp), 814 class_func_(class_func), 815 threshold_(threshold) {} 816 817 ~PruneQueue() { delete queue_; } 818 819 StateId Head() const { return queue_->Head(); } 820 821 void Enqueue(StateId s) { 822 StateId c = class_func_(s); 823 if (c >= class_distance_.size()) 824 class_distance_.resize(c + 1, Weight::Zero()); 825 if (less_(distance_[s], class_distance_[c])) 826 class_distance_[c] = distance_[s]; 827 828 // Enqueue only if below threshold limit 829 Weight limit = Times(class_distance_[c], threshold_); 830 if (less_(distance_[s], limit)) 831 queue_->Enqueue(s); 832 } 833 834 void Dequeue() { queue_->Dequeue(); } 835 836 void Update(StateId s) { 837 StateId c = class_func_(s); 838 if (less_(distance_[s], class_distance_[c])) 839 class_distance_[c] = distance_[s]; 840 queue_->Update(s); 841 } 842 843 bool Empty() const { return queue_->Empty(); } 844 void Clear() { queue_->Clear(); } 845 846 private: 847 // This allows base-class virtual access to non-virtual derived- 848 // class members 
of the same name. It makes the derived class more 849 // efficient to use but unsafe to further derive. 850 virtual StateId Head_() const { return Head(); } 851 virtual void Enqueue_(StateId s) { Enqueue(s); } 852 virtual void Dequeue_() { Dequeue(); } 853 virtual void Update_(StateId s) { Update(s); } 854 virtual bool Empty_() const { return Empty(); } 855 virtual void Clear_() { return Clear(); } 856 857 const vector<Weight> &distance_; // shortest distance to state 858 Q *queue_; 859 L less_; 860 const C &class_func_; // eqv. class function object 861 Weight threshold_; // pruning weight threshold 862 vector<Weight> class_distance_; // shortest distance to class 863 864 DISALLOW_COPY_AND_ASSIGN(PruneQueue); 865 }; 866 867 868 // Pruning queue discipline (see above) using the weight's natural 869 // order for the comparison function. The ownership of 'queue' is 870 // given to this class. 871 template <typename Q, typename W, typename C> 872 class NaturalPruneQueue : 873 public PruneQueue<Q, NaturalLess<W>, C> { 874 public: 875 typedef typename Q::StateId StateId; 876 typedef W Weight; 877 878 NaturalPruneQueue(const vector<W> &distance, Q *queue, 879 const C &class_func_, Weight threshold) : 880 PruneQueue<Q, NaturalLess<W>, C>(distance, queue, less_, 881 class_func_, threshold) {} 882 883 private: 884 NaturalLess<W> less_; 885 }; 886 887 888 // Filter-based pruning queue discipline: Enqueues a state 's' only 889 // if allowed by the filter, specified by the function object 'state_filter'. 890 // The underlying queue discipline is specified by 'queue'. The ownership 891 // of 'queue' is given to this class. 
892 template <typename Q, typename F> 893 class FilterQueue : public QueueBase<typename Q::StateId> { 894 public: 895 typedef typename Q::StateId StateId; 896 897 FilterQueue(Q *queue, const F &state_filter) 898 : QueueBase<StateId>(OTHER_QUEUE), 899 queue_(queue), 900 state_filter_(state_filter) {} 901 902 ~FilterQueue() { delete queue_; } 903 904 StateId Head() const { return queue_->Head(); } 905 906 // Enqueues only if allowed by state filter. 907 void Enqueue(StateId s) { 908 if (state_filter_(s)) { 909 queue_->Enqueue(s); 910 } 911 } 912 913 void Dequeue() { queue_->Dequeue(); } 914 915 void Update(StateId s) {} 916 bool Empty() const { return queue_->Empty(); } 917 void Clear() { queue_->Clear(); } 918 919 private: 920 // This allows base-class virtual access to non-virtual derived- 921 // class members of the same name. It makes the derived class more 922 // efficient to use but unsafe to further derive. 923 virtual StateId Head_() const { return Head(); } 924 virtual void Enqueue_(StateId s) { Enqueue(s); } 925 virtual void Dequeue_() { Dequeue(); } 926 virtual void Update_(StateId s) { Update(s); } 927 virtual bool Empty_() const { return Empty(); } 928 virtual void Clear_() { return Clear(); } 929 930 Q *queue_; 931 const F &state_filter_; // Filter to prune states 932 933 DISALLOW_COPY_AND_ASSIGN(FilterQueue); 934 }; 935 936 } // namespace fst 937 938 #endif 939