Home | History | Annotate | Download | only in pdt
      1 // expand.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // \file
     19 // Expand a PDT to an FST.
     20 
     21 #ifndef FST_EXTENSIONS_PDT_EXPAND_H__
     22 #define FST_EXTENSIONS_PDT_EXPAND_H__
     23 
     24 #include <vector>
     25 using std::vector;
     26 
     27 #include <fst/extensions/pdt/pdt.h>
     28 #include <fst/extensions/pdt/paren.h>
     29 #include <fst/extensions/pdt/shortest-path.h>
     30 #include <fst/extensions/pdt/reverse.h>
     31 #include <fst/cache.h>
     32 #include <fst/mutable-fst.h>
     33 #include <fst/queue.h>
     34 #include <fst/state-table.h>
     35 #include <fst/test-properties.h>
     36 
     37 namespace fst {
     38 
     39 template <class Arc>
     40 struct ExpandFstOptions : public CacheOptions {
     41   bool keep_parentheses;
     42   PdtStack<typename Arc::StateId, typename Arc::Label> *stack;
     43   PdtStateTable<typename Arc::StateId, typename Arc::StateId> *state_table;
     44 
     45   ExpandFstOptions(
     46       const CacheOptions &opts = CacheOptions(),
     47       bool kp = false,
     48       PdtStack<typename Arc::StateId, typename Arc::Label> *s = 0,
     49       PdtStateTable<typename Arc::StateId, typename Arc::StateId> *st = 0)
     50       : CacheOptions(opts), keep_parentheses(kp), stack(s), state_table(st) {}
     51 };
     52 
     53 // Properties for an expanded PDT.
     54 inline uint64 ExpandProperties(uint64 inprops) {
     55   return inprops & (kAcceptor | kAcyclic | kInitialAcyclic | kUnweighted);
     56 }
     57 
     58 
     59 // Implementation class for ExpandFst
     60 template <class A>
     61 class ExpandFstImpl
     62     : public CacheImpl<A> {
     63  public:
     64   using FstImpl<A>::SetType;
     65   using FstImpl<A>::SetProperties;
     66   using FstImpl<A>::Properties;
     67   using FstImpl<A>::SetInputSymbols;
     68   using FstImpl<A>::SetOutputSymbols;
     69 
     70   using CacheBaseImpl< CacheState<A> >::PushArc;
     71   using CacheBaseImpl< CacheState<A> >::HasArcs;
     72   using CacheBaseImpl< CacheState<A> >::HasFinal;
     73   using CacheBaseImpl< CacheState<A> >::HasStart;
     74   using CacheBaseImpl< CacheState<A> >::SetArcs;
     75   using CacheBaseImpl< CacheState<A> >::SetFinal;
     76   using CacheBaseImpl< CacheState<A> >::SetStart;
     77 
     78   typedef A Arc;
     79   typedef typename A::Label Label;
     80   typedef typename A::Weight Weight;
     81   typedef typename A::StateId StateId;
     82   typedef StateId StackId;
     83   typedef PdtStateTuple<StateId, StackId> StateTuple;
     84 
     85   ExpandFstImpl(const Fst<A> &fst,
     86                 const vector<pair<typename Arc::Label,
     87                                   typename Arc::Label> > &parens,
     88                 const ExpandFstOptions<A> &opts)
     89       : CacheImpl<A>(opts), fst_(fst.Copy()),
     90         stack_(opts.stack ? opts.stack: new PdtStack<StateId, Label>(parens)),
     91         state_table_(opts.state_table ? opts.state_table :
     92                      new PdtStateTable<StateId, StackId>()),
     93         own_stack_(opts.stack == 0), own_state_table_(opts.state_table == 0),
     94         keep_parentheses_(opts.keep_parentheses) {
     95     SetType("expand");
     96 
     97     uint64 props = fst.Properties(kFstProperties, false);
     98     SetProperties(ExpandProperties(props), kCopyProperties);
     99 
    100     SetInputSymbols(fst.InputSymbols());
    101     SetOutputSymbols(fst.OutputSymbols());
    102   }
    103 
    104   ExpandFstImpl(const ExpandFstImpl &impl)
    105       : CacheImpl<A>(impl),
    106         fst_(impl.fst_->Copy(true)),
    107         stack_(new PdtStack<StateId, Label>(*impl.stack_)),
    108         state_table_(new PdtStateTable<StateId, StackId>()),
    109         own_stack_(true), own_state_table_(true),
    110         keep_parentheses_(impl.keep_parentheses_) {
    111     SetType("expand");
    112     SetProperties(impl.Properties(), kCopyProperties);
    113     SetInputSymbols(impl.InputSymbols());
    114     SetOutputSymbols(impl.OutputSymbols());
    115   }
    116 
    117   ~ExpandFstImpl() {
    118     delete fst_;
    119     if (own_stack_)
    120       delete stack_;
    121     if (own_state_table_)
    122       delete state_table_;
    123   }
    124 
    125   StateId Start() {
    126     if (!HasStart()) {
    127       StateId s = fst_->Start();
    128       if (s == kNoStateId)
    129         return kNoStateId;
    130       StateTuple tuple(s, 0);
    131       StateId start = state_table_->FindState(tuple);
    132       SetStart(start);
    133     }
    134     return CacheImpl<A>::Start();
    135   }
    136 
    137   Weight Final(StateId s) {
    138     if (!HasFinal(s)) {
    139       const StateTuple &tuple = state_table_->Tuple(s);
    140       Weight w = fst_->Final(tuple.state_id);
    141       if (w != Weight::Zero() && tuple.stack_id == 0)
    142         SetFinal(s, w);
    143       else
    144         SetFinal(s, Weight::Zero());
    145     }
    146     return CacheImpl<A>::Final(s);
    147   }
    148 
    149   size_t NumArcs(StateId s) {
    150     if (!HasArcs(s)) {
    151       ExpandState(s);
    152     }
    153     return CacheImpl<A>::NumArcs(s);
    154   }
    155 
    156   size_t NumInputEpsilons(StateId s) {
    157     if (!HasArcs(s))
    158       ExpandState(s);
    159     return CacheImpl<A>::NumInputEpsilons(s);
    160   }
    161 
    162   size_t NumOutputEpsilons(StateId s) {
    163     if (!HasArcs(s))
    164       ExpandState(s);
    165     return CacheImpl<A>::NumOutputEpsilons(s);
    166   }
    167 
    168   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
    169     if (!HasArcs(s))
    170       ExpandState(s);
    171     CacheImpl<A>::InitArcIterator(s, data);
    172   }
    173 
    174   // Computes the outgoing transitions from a state, creating new destination
    175   // states as needed.
    176   void ExpandState(StateId s) {
    177     StateTuple tuple = state_table_->Tuple(s);
    178     for (ArcIterator< Fst<A> > aiter(*fst_, tuple.state_id);
    179          !aiter.Done(); aiter.Next()) {
    180       Arc arc = aiter.Value();
    181       StackId stack_id = stack_->Find(tuple.stack_id, arc.ilabel);
    182       if (stack_id == -1) {
    183         // Non-matching close parenthesis
    184         continue;
    185       } else if ((stack_id != tuple.stack_id) && !keep_parentheses_) {
    186         // Stack push/pop
    187         arc.ilabel = arc.olabel = 0;
    188       }
    189 
    190       StateTuple ntuple(arc.nextstate, stack_id);
    191       arc.nextstate = state_table_->FindState(ntuple);
    192       PushArc(s, arc);
    193     }
    194     SetArcs(s);
    195   }
    196 
    197   const PdtStack<StackId, Label> &GetStack() const { return *stack_; }
    198 
    199   const PdtStateTable<StateId, StackId> &GetStateTable() const {
    200     return *state_table_;
    201   }
    202 
    203  private:
    204   const Fst<A> *fst_;
    205 
    206   PdtStack<StackId, Label> *stack_;
    207   PdtStateTable<StateId, StackId> *state_table_;
    208   bool own_stack_;
    209   bool own_state_table_;
    210   bool keep_parentheses_;
    211 
    212   void operator=(const ExpandFstImpl<A> &);  // disallow
    213 };
    214 
    215 // Expands a pushdown transducer (PDT) encoded as an FST into an FST.
    216 // This version is a delayed Fst.  In the PDT, some transitions are
    217 // labeled with open or close parentheses. To be interpreted as a PDT,
    218 // the parens must balance on a path. The open-close parenthesis label
    219 // pairs are passed in 'parens'. The expansion enforces the
    220 // parenthesis constraints. The PDT must be expandable as an FST.
    221 //
    222 // This class attaches interface to implementation and handles
    223 // reference counting, delegating most methods to ImplToFst.
    224 template <class A>
    225 class ExpandFst : public ImplToFst< ExpandFstImpl<A> > {
    226  public:
    227   friend class ArcIterator< ExpandFst<A> >;
    228   friend class StateIterator< ExpandFst<A> >;
    229 
    230   typedef A Arc;
    231   typedef typename A::Label Label;
    232   typedef typename A::Weight Weight;
    233   typedef typename A::StateId StateId;
    234   typedef StateId StackId;
    235   typedef CacheState<A> State;
    236   typedef ExpandFstImpl<A> Impl;
    237 
    238   ExpandFst(const Fst<A> &fst,
    239             const vector<pair<typename Arc::Label,
    240                               typename Arc::Label> > &parens)
    241       : ImplToFst<Impl>(new Impl(fst, parens, ExpandFstOptions<A>())) {}
    242 
    243   ExpandFst(const Fst<A> &fst,
    244             const vector<pair<typename Arc::Label,
    245                               typename Arc::Label> > &parens,
    246             const ExpandFstOptions<A> &opts)
    247       : ImplToFst<Impl>(new Impl(fst, parens, opts)) {}
    248 
    249   // See Fst<>::Copy() for doc.
    250   ExpandFst(const ExpandFst<A> &fst, bool safe = false)
    251       : ImplToFst<Impl>(fst, safe) {}
    252 
    253   // Get a copy of this ExpandFst. See Fst<>::Copy() for further doc.
    254   virtual ExpandFst<A> *Copy(bool safe = false) const {
    255     return new ExpandFst<A>(*this, safe);
    256   }
    257 
    258   virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
    259 
    260   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    261     GetImpl()->InitArcIterator(s, data);
    262   }
    263 
    264   const PdtStack<StackId, Label> &GetStack() const {
    265     return GetImpl()->GetStack();
    266   }
    267 
    268   const PdtStateTable<StateId, StackId> &GetStateTable() const {
    269     return GetImpl()->GetStateTable();
    270   }
    271 
    272  private:
    273   // Makes visible to friends.
    274   Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
    275 
    276   void operator=(const ExpandFst<A> &fst);  // Disallow
    277 };
    278 
    279 
    280 // Specialization for ExpandFst.
    281 template<class A>
    282 class StateIterator< ExpandFst<A> >
    283     : public CacheStateIterator< ExpandFst<A> > {
    284  public:
    285   explicit StateIterator(const ExpandFst<A> &fst)
    286       : CacheStateIterator< ExpandFst<A> >(fst, fst.GetImpl()) {}
    287 };
    288 
    289 
    290 // Specialization for ExpandFst.
    291 template <class A>
    292 class ArcIterator< ExpandFst<A> >
    293     : public CacheArcIterator< ExpandFst<A> > {
    294  public:
    295   typedef typename A::StateId StateId;
    296 
    297   ArcIterator(const ExpandFst<A> &fst, StateId s)
    298       : CacheArcIterator< ExpandFst<A> >(fst.GetImpl(), s) {
    299     if (!fst.GetImpl()->HasArcs(s))
    300       fst.GetImpl()->ExpandState(s);
    301   }
    302 
    303  private:
    304   DISALLOW_COPY_AND_ASSIGN(ArcIterator);
    305 };
    306 
    307 
    308 template <class A> inline
    309 void ExpandFst<A>::InitStateIterator(StateIteratorData<A> *data) const
    310 {
    311   data->base = new StateIterator< ExpandFst<A> >(*this);
    312 }
    313 
    314 //
    315 // PrunedExpand Class
    316 //
    317 
    318 // Prunes the delayed expansion of a pushdown transducer (PDT) encoded
    319 // as an FST into an FST.  In the PDT, some transitions are labeled
    320 // with open or close parentheses. To be interpreted as a PDT, the
    321 // parens must balance on a path. The open-close parenthesis label
    322 // pairs are passed in 'parens'. The expansion enforces the
    323 // parenthesis constraints.
    324 //
    325 // The algorithm works by visiting the delayed ExpandFst using a
    326 // shortest-stack first queue discipline and relies on the
    327 // shortest-distance information computed using a reverse
    328 // shortest-path call to perform the pruning.
    329 //
    330 // The algorithm maintains the same state ordering between the ExpandFst
    331 // being visited 'efst_' and the result of pruning written into the
    332 // MutableFst 'ofst_' to improve readability of the code.
    333 //
    334 template <class A>
    335 class PrunedExpand {
    336  public:
    337   typedef A Arc;
    338   typedef typename A::Label Label;
    339   typedef typename A::StateId StateId;
    340   typedef typename A::Weight Weight;
    341   typedef StateId StackId;
    342   typedef PdtStack<StackId, Label> Stack;
    343   typedef PdtStateTable<StateId, StackId> StateTable;
    344   typedef typename PdtBalanceData<Arc>::SetIterator SetIterator;
    345 
    346   // Constructor taking as input a PDT specified by 'ifst' and 'parens'.
    347   // 'keep_parentheses' specifies whether parentheses are replaced by
    348   // epsilons or not during the expansion. 'opts' is the cache options
    349   // used to instantiate the underlying ExpandFst.
    350   PrunedExpand(const Fst<A> &ifst,
    351                const vector<pair<Label, Label> > &parens,
    352                bool keep_parentheses = false,
    353                const CacheOptions &opts = CacheOptions())
    354       : ifst_(ifst.Copy()),
    355         keep_parentheses_(keep_parentheses),
    356         stack_(parens),
    357         efst_(ifst, parens,
    358               ExpandFstOptions<Arc>(opts, true, &stack_, &state_table_)),
    359         queue_(state_table_, stack_, stack_length_, distance_, fdistance_) {
    360     Reverse(*ifst_, parens, &rfst_);
    361     VectorFst<Arc> path;
    362     reverse_shortest_path_ = new SP(
    363         rfst_, parens,
    364         PdtShortestPathOptions<A, FifoQueue<StateId> >(true, false));
    365     reverse_shortest_path_->ShortestPath(&path);
    366     balance_data_ = reverse_shortest_path_->GetBalanceData()->Reverse(
    367         rfst_.NumStates(), 10, -1);
    368 
    369     InitCloseParenMultimap(parens);
    370   }
    371 
    372   ~PrunedExpand() {
    373     delete ifst_;
    374     delete reverse_shortest_path_;
    375     delete balance_data_;
    376   }
    377 
    378   // Expands and prunes with weight threshold 'threshold' the input PDT.
    379   // Writes the result in 'ofst'.
    380   void Expand(MutableFst<A> *ofst, const Weight &threshold);
    381 
    382  private:
    383   static const uint8 kEnqueued;
    384   static const uint8 kExpanded;
    385   static const uint8 kSourceState;
    386 
    387   // Comparison functor used by the queue:
    388   // 1. states corresponding to shortest stack first,
    389   // 2. among stacks of the same length, reverse lexicographic order is used,
    390   // 3. among states with the same stack, shortest-first order is used.
    391   class StackCompare {
    392    public:
    393     StackCompare(const StateTable &st,
    394                  const Stack &s, const vector<StackId> &sl,
    395                  const vector<Weight> &d, const vector<Weight> &fd)
    396         : state_table_(st), stack_(s), stack_length_(sl),
    397           distance_(d), fdistance_(fd) {}
    398 
    399     bool operator()(StateId s1, StateId s2) const {
    400       StackId si1 = state_table_.Tuple(s1).stack_id;
    401       StackId si2 = state_table_.Tuple(s2).stack_id;
    402       if (stack_length_[si1] < stack_length_[si2])
    403         return true;
    404       if  (stack_length_[si1] > stack_length_[si2])
    405         return false;
    406       // If stack id equal, use A*
    407       if (si1 == si2) {
    408         Weight w1 = (s1 < distance_.size()) && (s1 < fdistance_.size()) ?
    409             Times(distance_[s1], fdistance_[s1]) : Weight::Zero();
    410         Weight w2 = (s2 < distance_.size()) && (s2 < fdistance_.size()) ?
    411             Times(distance_[s2], fdistance_[s2]) : Weight::Zero();
    412         return less_(w1, w2);
    413       }
    414       // If lenghts are equal, use reverse lexico.
    415       for (; si1 != si2; si1 = stack_.Pop(si1), si2 = stack_.Pop(si2)) {
    416         if (stack_.Top(si1) < stack_.Top(si2)) return true;
    417         if (stack_.Top(si1) > stack_.Top(si2)) return false;
    418       }
    419       return false;
    420     }
    421 
    422    private:
    423     const StateTable &state_table_;
    424     const Stack &stack_;
    425     const vector<StackId> &stack_length_;
    426     const vector<Weight> &distance_;
    427     const vector<Weight> &fdistance_;
    428     NaturalLess<Weight> less_;
    429   };
    430 
    431   class ShortestStackFirstQueue
    432       : public ShortestFirstQueue<StateId, StackCompare> {
    433    public:
    434     ShortestStackFirstQueue(
    435         const PdtStateTable<StateId, StackId> &st,
    436         const Stack &s,
    437         const vector<StackId> &sl,
    438         const vector<Weight> &d, const vector<Weight> &fd)
    439         : ShortestFirstQueue<StateId, StackCompare>(
    440             StackCompare(st, s, sl, d, fd)) {}
    441   };
    442 
    443 
    444   void InitCloseParenMultimap(const vector<pair<Label, Label> > &parens);
    445   Weight DistanceToDest(StateId state, StateId source) const;
    446   uint8 Flags(StateId s) const;
    447   void SetFlags(StateId s, uint8 flags, uint8 mask);
    448   Weight Distance(StateId s) const;
    449   void SetDistance(StateId s, Weight w);
    450   Weight FinalDistance(StateId s) const;
    451   void SetFinalDistance(StateId s, Weight w);
    452   StateId SourceState(StateId s) const;
    453   void SetSourceState(StateId s, StateId p);
    454   void AddStateAndEnqueue(StateId s);
    455   void Relax(StateId s, const A &arc, Weight w);
    456   bool PruneArc(StateId s, const A &arc);
    457   void ProcStart();
    458   void ProcFinal(StateId s);
    459   bool ProcNonParen(StateId s, const A &arc, bool add_arc);
    460   bool ProcOpenParen(StateId s, const A &arc, StackId si, StackId nsi);
    461   bool ProcCloseParen(StateId s, const A &arc);
    462   void ProcDestStates(StateId s, StackId si);
    463 
    464   Fst<A> *ifst_;                   // Input PDT
    465   VectorFst<Arc> rfst_;            // Reversed PDT
    466   bool keep_parentheses_;          // Keep parentheses in ofst?
    467   StateTable state_table_;         // State table for efst_
    468   Stack stack_;                    // Stack trie
    469   ExpandFst<Arc> efst_;            // Expanded PDT
    470   vector<StackId> stack_length_;   // Length of stack for given stack id
    471   vector<Weight> distance_;        // Distance from initial state in efst_/ofst
    472   vector<Weight> fdistance_;       // Distance to final states in efst_/ofst
    473   ShortestStackFirstQueue queue_;  // Queue used to visit efst_
    474   vector<uint8> flags_;            // Status flags for states in efst_/ofst
    475   vector<StateId> sources_;        // PDT source state for each expanded state
    476 
    477   typedef PdtShortestPath<Arc, FifoQueue<StateId> > SP;
    478   typedef typename SP::CloseParenMultimap ParenMultimap;
    479   SP *reverse_shortest_path_;  // Shortest path for rfst_
    480   PdtBalanceData<Arc> *balance_data_;   // Not owned by shortest_path_
    481   ParenMultimap close_paren_multimap_;  // Maps open paren arcs to
    482   // balancing close paren arcs.
    483 
    484   MutableFst<Arc> *ofst_;  // Output fst
    485   Weight limit_;           // Weight limit
    486 
    487   typedef unordered_map<StateId, Weight> DestMap;
    488   DestMap dest_map_;
    489   StackId current_stack_id_;
    490   // 'current_stack_id_' is the stack id of the states currently at the top
    491   // of queue, i.e., the states currently being popped and processed.
    492   // 'dest_map_' maps a state 's' in 'ifst_' that is the source
    493   // of a close parentheses matching the top of 'current_stack_id_; to
    494   // the shortest-distance from '(s, current_stack_id_)' to the final
    495   // states in 'efst_'.
    496   ssize_t current_paren_id_;  // Paren id at top of current stack
    497   ssize_t cached_stack_id_;
    498   StateId cached_source_;
    499   slist<pair<StateId, Weight> > cached_dest_list_;
    500   // 'cached_dest_list_' contains the set of pair of destination
    501   // states and weight to final states for source state
    502   // 'cached_source_' and paren id 'cached_paren_id': the set of
    503   // source state of a close parenthesis with paren id
    504   // 'cached_paren_id' balancing an incoming open parenthesis with
    505   // paren id 'cached_paren_id' in state 'cached_source_'.
    506 
    507   NaturalLess<Weight> less_;
    508 };
    509 
    510 template <class A> const uint8 PrunedExpand<A>::kEnqueued = 0x01;
    511 template <class A> const uint8 PrunedExpand<A>::kExpanded = 0x02;
    512 template <class A> const uint8 PrunedExpand<A>::kSourceState = 0x04;
    513 
    514 
    515 // Initializes close paren multimap, mapping pairs (s,paren_id) to
    516 // all the arcs out of s labeled with close parenthese for paren_id.
    517 template <class A>
    518 void PrunedExpand<A>::InitCloseParenMultimap(
    519     const vector<pair<Label, Label> > &parens) {
    520   unordered_map<Label, Label> paren_id_map;
    521   for (Label i = 0; i < parens.size(); ++i) {
    522     const pair<Label, Label>  &p = parens[i];
    523     paren_id_map[p.first] = i;
    524     paren_id_map[p.second] = i;
    525   }
    526 
    527   for (StateIterator<Fst<Arc> > siter(*ifst_); !siter.Done(); siter.Next()) {
    528     StateId s = siter.Value();
    529     for (ArcIterator<Fst<Arc> > aiter(*ifst_, s);
    530          !aiter.Done(); aiter.Next()) {
    531       const Arc &arc = aiter.Value();
    532       typename unordered_map<Label, Label>::const_iterator pit
    533           = paren_id_map.find(arc.ilabel);
    534       if (pit == paren_id_map.end()) continue;
    535       if (arc.ilabel == parens[pit->second].second) {  // Close paren
    536         ParenState<Arc> paren_state(pit->second, s);
    537         close_paren_multimap_.insert(make_pair(paren_state, arc));
    538       }
    539     }
    540   }
    541 }
    542 
    543 
    544 // Returns the weight of the shortest balanced path from 'source' to 'dest'
    545 // in 'ifst_', 'dest' must be the source state of a close paren arc.
    546 template <class A>
    547 typename A::Weight PrunedExpand<A>::DistanceToDest(StateId source,
    548                                                    StateId dest) const {
    549   typename SP::SearchState s(source + 1, dest + 1);
    550   VLOG(2) << "D(" << source << ", " << dest << ") ="
    551             << reverse_shortest_path_->GetShortestPathData().Distance(s);
    552   return reverse_shortest_path_->GetShortestPathData().Distance(s);
    553 }
    554 
    555 // Returns the flags for state 's' in 'ofst_'.
    556 template <class A>
    557 uint8 PrunedExpand<A>::Flags(StateId s) const {
    558   return s < flags_.size() ? flags_[s] : 0;
    559 }
    560 
    561 // Modifies the flags for state 's' in 'ofst_'.
    562 template <class A>
    563 void PrunedExpand<A>::SetFlags(StateId s, uint8 flags, uint8 mask) {
    564   while (flags_.size() <= s) flags_.push_back(0);
    565   flags_[s] &= ~mask;
    566   flags_[s] |= flags & mask;
    567 }
    568 
    569 
    570 // Returns the shortest distance from the initial state to 's' in 'ofst_'.
    571 template <class A>
    572 typename A::Weight PrunedExpand<A>::Distance(StateId s) const {
    573   return s < distance_.size() ? distance_[s] : Weight::Zero();
    574 }
    575 
    576 // Sets the shortest distance from the initial state to 's' in 'ofst_' to 'w'.
    577 template <class A>
    578 void PrunedExpand<A>::SetDistance(StateId s, Weight w) {
    579   while (distance_.size() <= s ) distance_.push_back(Weight::Zero());
    580   distance_[s] = w;
    581 }
    582 
    583 
    584 // Returns the shortest distance from 's' to the final states in 'ofst_'.
    585 template <class A>
    586 typename A::Weight PrunedExpand<A>::FinalDistance(StateId s) const {
    587   return s < fdistance_.size() ? fdistance_[s] : Weight::Zero();
    588 }
    589 
    590 // Sets the shortest distance from 's' to the final states in 'ofst_' to 'w'.
    591 template <class A>
    592 void PrunedExpand<A>::SetFinalDistance(StateId s, Weight w) {
    593   while (fdistance_.size() <= s) fdistance_.push_back(Weight::Zero());
    594   fdistance_[s] = w;
    595 }
    596 
    597 // Returns the PDT "source" state of state 's' in 'ofst_'.
    598 template <class A>
    599 typename A::StateId PrunedExpand<A>::SourceState(StateId s) const {
    600   return s < sources_.size() ? sources_[s] : kNoStateId;
    601 }
    602 
    603 // Sets the PDT "source" state of state 's' in 'ofst_' to state 'p' in 'ifst_'.
    604 template <class A>
    605 void PrunedExpand<A>::SetSourceState(StateId s, StateId p) {
    606   while (sources_.size() <= s) sources_.push_back(kNoStateId);
    607   sources_[s] = p;
    608 }
    609 
    610 // Adds state 's' of 'efst_' to 'ofst_' and inserts it in the queue,
    611 // modifying the flags for 's' accordingly.
    612 template <class A>
    613 void PrunedExpand<A>::AddStateAndEnqueue(StateId s) {
    614   if (!(Flags(s) & (kEnqueued | kExpanded))) {
    615     while (ofst_->NumStates() <= s) ofst_->AddState();
    616     queue_.Enqueue(s);
    617     SetFlags(s, kEnqueued, kEnqueued);
    618   } else if (Flags(s) & kEnqueued) {
    619     queue_.Update(s);
    620   }
    621   // TODO(allauzen): Check everything is fine when kExpanded?
    622 }
    623 
    624 // Relaxes arc 'arc' out of state 's' in 'ofst_':
    625 // * if the distance to 's' times the weight of 'arc' is smaller than
    626 //   the currently stored distance for 'arc.nextstate',
    627 //   updates 'Distance(arc.nextstate)' with new estimate;
    628 // * if 'fd' is less than the currently stored distance from 'arc.nextstate'
    629 //   to the final state, updates with new estimate.
    630 template <class A>
    631 void PrunedExpand<A>::Relax(StateId s, const A &arc, Weight fd) {
    632   Weight nd = Times(Distance(s), arc.weight);
    633   if (less_(nd, Distance(arc.nextstate))) {
    634     SetDistance(arc.nextstate, nd);
    635     SetSourceState(arc.nextstate, SourceState(s));
    636   }
    637   if (less_(fd, FinalDistance(arc.nextstate)))
    638     SetFinalDistance(arc.nextstate, fd);
    639   VLOG(2) << "Relax: " << s << ", d[s] = " << Distance(s) << ", to "
    640             << arc.nextstate << ", d[ns] = " << Distance(arc.nextstate)
    641             << ", nd = " << nd;
    642 }
    643 
    644 // Returns 'true' if the arc 'arc' out of state 's' in 'efst_' needs to
    645 // be pruned.
    646 template <class A>
    647 bool PrunedExpand<A>::PruneArc(StateId s, const A &arc) {
    648   VLOG(2) << "Prune ?";
    649   Weight fd = Weight::Zero();
    650 
    651   if ((cached_source_ != SourceState(s)) ||
    652       (cached_stack_id_ != current_stack_id_)) {
    653     cached_source_ = SourceState(s);
    654     cached_stack_id_ = current_stack_id_;
    655     cached_dest_list_.clear();
    656     if (cached_source_ != ifst_->Start()) {
    657       for (SetIterator set_iter =
    658                balance_data_->Find(current_paren_id_, cached_source_);
    659            !set_iter.Done(); set_iter.Next()) {
    660         StateId dest = set_iter.Element();
    661         typename DestMap::const_iterator iter = dest_map_.find(dest);
    662         cached_dest_list_.push_front(*iter);
    663       }
    664     } else {
    665       // TODO(allauzen): queue discipline should prevent this never
    666       // from happening; replace by a check.
    667       cached_dest_list_.push_front(
    668           make_pair(rfst_.Start() -1, Weight::One()));
    669     }
    670   }
    671 
    672   for (typename slist<pair<StateId, Weight> >::const_iterator iter =
    673            cached_dest_list_.begin();
    674        iter != cached_dest_list_.end();
    675        ++iter) {
    676     fd = Plus(fd,
    677               Times(DistanceToDest(state_table_.Tuple(arc.nextstate).state_id,
    678                                    iter->first),
    679                     iter->second));
    680   }
    681   Relax(s, arc, fd);
    682   Weight w = Times(Distance(s), Times(arc.weight, fd));
    683   return less_(limit_, w);
    684 }
    685 
    686 // Adds start state of 'efst_' to 'ofst_', enqueues it and initializes
    687 // the distance data structures.
    688 template <class A>
    689 void PrunedExpand<A>::ProcStart() {
    690   StateId s = efst_.Start();
    691   AddStateAndEnqueue(s);
    692   ofst_->SetStart(s);
    693   SetSourceState(s, ifst_->Start());
    694 
    695   current_stack_id_ = 0;
    696   current_paren_id_ = -1;
    697   stack_length_.push_back(0);
    698   dest_map_[rfst_.Start() - 1] = Weight::One(); // not needed
    699 
    700   cached_source_ = ifst_->Start();
    701   cached_stack_id_ = 0;
    702   cached_dest_list_.push_front(
    703           make_pair(rfst_.Start() -1, Weight::One()));
    704 
    705   PdtStateTuple<StateId, StackId> tuple(rfst_.Start() - 1, 0);
    706   SetFinalDistance(state_table_.FindState(tuple), Weight::One());
    707   SetDistance(s, Weight::One());
    708   SetFinalDistance(s, DistanceToDest(ifst_->Start(), rfst_.Start() - 1));
    709   VLOG(2) << DistanceToDest(ifst_->Start(), rfst_.Start() - 1);
    710 }
    711 
    712 // Makes 's' final in 'ofst_' if shortest accepting path ending in 's'
    713 // is below threshold.
    714 template <class A>
    715 void PrunedExpand<A>::ProcFinal(StateId s) {
    716   Weight final = efst_.Final(s);
    717   if ((final == Weight::Zero()) || less_(limit_, Times(Distance(s), final)))
    718     return;
    719   ofst_->SetFinal(s, final);
    720 }
    721 
    722 // Returns true when arc (or meta-arc) 'arc' out of 's' in 'efst_' is
    723 // below the threshold.  When 'add_arc' is true, 'arc' is added to 'ofst_'.
    724 template <class A>
    725 bool PrunedExpand<A>::ProcNonParen(StateId s, const A &arc, bool add_arc) {
    726   VLOG(2) << "ProcNonParen: " << s << " to " << arc.nextstate
    727           << ", " << arc.ilabel << ":" << arc.olabel << " / " << arc.weight
    728           << ", add_arc = " << (add_arc ? "true" : "false");
    729   if (PruneArc(s, arc)) return false;
    730   if(add_arc) ofst_->AddArc(s, arc);
    731   AddStateAndEnqueue(arc.nextstate);
    732   return true;
    733 }
    734 
    735 // Processes an open paren arc 'arc' out of state 's' in 'ofst_'.
    736 // When 'arc' is labeled with an open paren,
    737 // 1. considers each (shortest) balanced path starting in 's' by
    738 //    taking 'arc' and ending by a close paren balancing the open
    739 //    paren of 'arc' as a meta-arc, processes and prunes each meta-arc
    740 //    as a non-paren arc, inserting its destination to the queue;
    741 // 2. if at least one of these meta-arcs has not been pruned,
    742 //    adds the destination of 'arc' to 'ofst_' as a new source state
    743 //    for the stack id 'nsi' and inserts it in the queue.
    744 template <class A>
    745 bool PrunedExpand<A>::ProcOpenParen(StateId s, const A &arc, StackId si,
    746                                     StackId nsi) {
    747   // Update the stack lenght when needed: |nsi| = |si| + 1.
    748   while (stack_length_.size() <= nsi) stack_length_.push_back(-1);
    749   if (stack_length_[nsi] == -1)
    750     stack_length_[nsi] = stack_length_[si] + 1;
    751 
    752   StateId ns = arc.nextstate;
    753   VLOG(2) << "Open paren: " << s << "(" << state_table_.Tuple(s).state_id
    754             << ") to " << ns << "(" << state_table_.Tuple(ns).state_id << ")";
    755   bool proc_arc = false;
    756   Weight fd = Weight::Zero();
    757   ssize_t paren_id = stack_.ParenId(arc.ilabel);
    758   slist<StateId> sources;
    759   for (SetIterator set_iter =
    760            balance_data_->Find(paren_id, state_table_.Tuple(ns).state_id);
    761        !set_iter.Done(); set_iter.Next()) {
    762     sources.push_front(set_iter.Element());
    763   }
    764   for (typename slist<StateId>::const_iterator sources_iter = sources.begin();
    765        sources_iter != sources.end();
    766        ++ sources_iter) {
    767     StateId source = *sources_iter;
    768     VLOG(2) << "Close paren source: " << source;
    769     ParenState<Arc> paren_state(paren_id, source);
    770     for (typename ParenMultimap::const_iterator iter =
    771              close_paren_multimap_.find(paren_state);
    772          iter != close_paren_multimap_.end() && paren_state == iter->first;
    773          ++iter) {
    774       Arc meta_arc = iter->second;
    775       PdtStateTuple<StateId, StackId> tuple(meta_arc.nextstate, si);
    776       meta_arc.nextstate =  state_table_.FindState(tuple);
    777       VLOG(2) << state_table_.Tuple(ns).state_id << ", " << source;
    778       VLOG(2) << "Meta arc weight = " << arc.weight << " Times "
    779                 << DistanceToDest(state_table_.Tuple(ns).state_id, source)
    780                 << " Times " << meta_arc.weight;
    781       meta_arc.weight = Times(
    782           arc.weight,
    783           Times(DistanceToDest(state_table_.Tuple(ns).state_id, source),
    784                 meta_arc.weight));
    785       proc_arc |= ProcNonParen(s, meta_arc, false);
    786       fd = Plus(fd, Times(
    787           Times(
    788               DistanceToDest(state_table_.Tuple(ns).state_id, source),
    789               iter->second.weight),
    790           FinalDistance(meta_arc.nextstate)));
    791     }
    792   }
    793   if (proc_arc) {
    794     VLOG(2) << "Proc open paren " << s << " to " << arc.nextstate;
    795     ofst_->AddArc(
    796       s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
    797     AddStateAndEnqueue(arc.nextstate);
    798     Weight nd = Times(Distance(s), arc.weight);
    799     if(less_(nd, Distance(arc.nextstate)))
    800       SetDistance(arc.nextstate, nd);
    801     // FinalDistance not necessary for source state since pruning
    802     // decided using the meta-arcs above.  But this is a problem with
    803     // A*, hence:
    804     if (less_(fd, FinalDistance(arc.nextstate)))
    805       SetFinalDistance(arc.nextstate, fd);
    806     SetFlags(arc.nextstate, kSourceState, kSourceState);
    807   }
    808   return proc_arc;
    809 }
    810 
    811 // Checks that shortest path through close paren arc in 'efst_' is
    812 // below threshold, if so adds it to 'ofst_'.
    813 template <class A>
    814 bool PrunedExpand<A>::ProcCloseParen(StateId s, const A &arc) {
    815   Weight w = Times(Distance(s),
    816                    Times(arc.weight, FinalDistance(arc.nextstate)));
    817   if (less_(limit_, w))
    818     return false;
    819   ofst_->AddArc(
    820       s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
    821   return true;
    822 }
    823 
    824 // When 's' in 'ofst_' is a source state for stack id 'si', identifies
    825 // all the corresponding possible destination states, that is, all the
    826 // states in 'ifst_' that have an outgoing close paren arc balancing
    827 // the incoming open paren taken to get to 's', and for each such
    828 // state 't', computes the shortest distance from (t, si) to the final
    829 // states in 'ofst_'. Stores this information in 'dest_map_'.
    830 template <class A>
    831 void PrunedExpand<A>::ProcDestStates(StateId s, StackId si) {
    832   if (!(Flags(s) & kSourceState)) return;
    833   if (si != current_stack_id_) {
    834     dest_map_.clear();
    835     current_stack_id_ = si;
    836     current_paren_id_ = stack_.Top(current_stack_id_);
    837     VLOG(2) << "StackID " << si << " dequeued for first time";
    838   }
    839   // TODO(allauzen): clean up source state business; rename current function to
    840   // ProcSourceState.
    841   SetSourceState(s, state_table_.Tuple(s).state_id);
    842 
    843   ssize_t paren_id = stack_.Top(si);
    844   for (SetIterator set_iter =
    845            balance_data_->Find(paren_id, state_table_.Tuple(s).state_id);
    846        !set_iter.Done(); set_iter.Next()) {
    847     StateId dest_state = set_iter.Element();
    848     if (dest_map_.find(dest_state) != dest_map_.end())
    849       continue;
    850     Weight dest_weight = Weight::Zero();
    851     ParenState<Arc> paren_state(paren_id, dest_state);
    852     for (typename ParenMultimap::const_iterator iter =
    853              close_paren_multimap_.find(paren_state);
    854          iter != close_paren_multimap_.end() && paren_state == iter->first;
    855          ++iter) {
    856       const Arc &arc = iter->second;
    857       PdtStateTuple<StateId, StackId> tuple(arc.nextstate, stack_.Pop(si));
    858       dest_weight = Plus(dest_weight,
    859                          Times(arc.weight,
    860                                FinalDistance(state_table_.FindState(tuple))));
    861     }
    862     dest_map_[dest_state] = dest_weight;
    863     VLOG(2) << "State " << dest_state << " is a dest state for stack id "
    864               << si << " with weight " << dest_weight;
    865   }
    866 }
    867 
    868 // Expands and prunes with weight threshold 'threshold' the input PDT.
    869 // Writes the result in 'ofst'.
    870 template <class A>
    871 void PrunedExpand<A>::Expand(
    872     MutableFst<A> *ofst, const typename A::Weight &threshold) {
    873   ofst_ = ofst;
    874   ofst_->DeleteStates();
    875   ofst_->SetInputSymbols(ifst_->InputSymbols());
    876   ofst_->SetOutputSymbols(ifst_->OutputSymbols());
    877 
    878   limit_ = Times(DistanceToDest(ifst_->Start(), rfst_.Start() - 1), threshold);
    879   flags_.clear();
    880 
    881   ProcStart();
    882 
    883   while (!queue_.Empty()) {
    884     StateId s = queue_.Head();
    885     queue_.Dequeue();
    886     SetFlags(s, kExpanded, kExpanded | kEnqueued);
    887     VLOG(2) << s << " dequeued!";
    888 
    889     ProcFinal(s);
    890     StackId stack_id = state_table_.Tuple(s).stack_id;
    891     ProcDestStates(s, stack_id);
    892 
    893     for (ArcIterator<ExpandFst<Arc> > aiter(efst_, s);
    894          !aiter.Done();
    895          aiter.Next()) {
    896       Arc arc = aiter.Value();
    897       StackId nextstack_id = state_table_.Tuple(arc.nextstate).stack_id;
    898       if (stack_id == nextstack_id)
    899         ProcNonParen(s, arc, true);
    900       else if (stack_id == stack_.Pop(nextstack_id))
    901         ProcOpenParen(s, arc, stack_id, nextstack_id);
    902       else
    903         ProcCloseParen(s, arc);
    904     }
    905     VLOG(2) << "d[" << s << "] = " << Distance(s)
    906             << ", fd[" << s << "] = " << FinalDistance(s);
    907   }
    908 }
    909 
    910 //
    911 // Expand() Functions
    912 //
    913 
    914 template <class Arc>
    915 struct ExpandOptions {
    916   bool connect;
    917   bool keep_parentheses;
    918   typename Arc::Weight weight_threshold;
    919 
    920   ExpandOptions(bool c  = true, bool k = false,
    921                 typename Arc::Weight w = Arc::Weight::Zero())
    922       : connect(c), keep_parentheses(k), weight_threshold(w) {}
    923 };
    924 
    925 // Expands a pushdown transducer (PDT) encoded as an FST into an FST.
    926 // This version writes the expanded PDT result to a MutableFst.
    927 // In the PDT, some transitions are labeled with open or close
    928 // parentheses. To be interpreted as a PDT, the parens must balance on
    929 // a path. The open-close parenthesis label pairs are passed in
    930 // 'parens'. The expansion enforces the parenthesis constraints. The
    931 // PDT must be expandable as an FST.
    932 template <class Arc>
    933 void Expand(
    934     const Fst<Arc> &ifst,
    935     const vector<pair<typename Arc::Label, typename Arc::Label> > &parens,
    936     MutableFst<Arc> *ofst,
    937     const ExpandOptions<Arc> &opts) {
    938   typedef typename Arc::Label Label;
    939   typedef typename Arc::StateId StateId;
    940   typedef typename Arc::Weight Weight;
    941   typedef typename ExpandFst<Arc>::StackId StackId;
    942 
    943   ExpandFstOptions<Arc> eopts;
    944   eopts.gc_limit = 0;
    945   if (opts.weight_threshold == Weight::Zero()) {
    946     eopts.keep_parentheses = opts.keep_parentheses;
    947     *ofst = ExpandFst<Arc>(ifst, parens, eopts);
    948   } else {
    949     PrunedExpand<Arc> pruned_expand(ifst, parens, opts.keep_parentheses);
    950     pruned_expand.Expand(ofst, opts.weight_threshold);
    951   }
    952 
    953   if (opts.connect)
    954     Connect(ofst);
    955 }
    956 
    957 // Expands a pushdown transducer (PDT) encoded as an FST into an FST.
    958 // This version writes the expanded PDT result to a MutableFst.
    959 // In the PDT, some transitions are labeled with open or close
    960 // parentheses. To be interpreted as a PDT, the parens must balance on
    961 // a path. The open-close parenthesis label pairs are passed in
    962 // 'parens'. The expansion enforces the parenthesis constraints. The
    963 // PDT must be expandable as an FST.
    964 template<class Arc>
    965 void Expand(
    966     const Fst<Arc> &ifst,
    967     const vector<pair<typename Arc::Label, typename Arc::Label> > &parens,
    968     MutableFst<Arc> *ofst,
    969     bool connect = true, bool keep_parentheses = false) {
    970   Expand(ifst, parens, ofst, ExpandOptions<Arc>(connect, keep_parentheses));
    971 }
    972 
    973 }  // namespace fst
    974 
    975 #endif  // FST_EXTENSIONS_PDT_EXPAND_H__
    976