Home | History | Annotate | Download | only in fst
      1 // visit.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // \file
     19 // Queue-dependent visitation of finite-state transducers. See also
     20 // dfs-visit.h.
     21 
     22 #ifndef FST_LIB_VISIT_H__
     23 #define FST_LIB_VISIT_H__
     24 
     25 
     26 #include <fst/arcfilter.h>
     27 #include <fst/mutable-fst.h>
     28 
     29 
     30 namespace fst {
     31 
     32 // Visitor Interface - class determines actions taken during a visit.
     33 // If any of the boolean member functions return false, the visit is
     34 // aborted by first calling FinishState() on all unfinished (grey)
     35 // states and then calling FinishVisit().
     36 //
     37 // Note this is more general than the visitor interface in
     38 // dfs-visit.h but lacks some DFS-specific behavior.
     39 //
     40 // template <class Arc>
     41 // class Visitor {
     42 //  public:
     43 //   typedef typename Arc::StateId StateId;
     44 //
     45 //   Visitor(T *return_data);
     46 //   // Invoked before visit
     47 //   void InitVisit(const Fst<Arc> &fst);
     48 //   // Invoked when state discovered (2nd arg is visitation root)
     49 //   bool InitState(StateId s, StateId root);
     50 //   // Invoked when arc to white/undiscovered state examined
     51 //   bool WhiteArc(StateId s, const Arc &a);
     52 //   // Invoked when arc to grey/unfinished state examined
     53 //   bool GreyArc(StateId s, const Arc &a);
     54 //   // Invoked when arc to black/finished state examined
     55 //   bool BlackArc(StateId s, const Arc &a);
     56 //   // Invoked when state finished.
     57 //   void FinishState(StateId s);
     58 //   // Invoked after visit
     59 //   void FinishVisit();
     60 // };
     61 
     62 // Performs queue-dependent visitation. Visitor class argument
     63 // determines actions and contains any return data. ArcFilter
     64 // determines arcs that are considered.
     65 //
     66 // Note this is more general than DfsVisit() in dfs-visit.h but lacks
     67 // some DFS-specific Visitor behavior.
     68 template <class Arc, class V, class Q, class ArcFilter>
     69 void Visit(const Fst<Arc> &fst, V *visitor, Q *queue, ArcFilter filter) {
     70 
     71   typedef typename Arc::StateId StateId;
     72   typedef ArcIterator< Fst<Arc> > AIterator;
     73 
     74   visitor->InitVisit(fst);
     75 
     76   StateId start = fst.Start();
     77   if (start == kNoStateId) {
     78     visitor->FinishVisit();
     79     return;
     80   }
     81 
     82   // An Fst state's visit color
     83   const unsigned kWhiteState =  0x01;    // Undiscovered
     84   const unsigned kGreyState =   0x02;    // Discovered & unfinished
     85   const unsigned kBlackState =  0x04;    // Finished
     86 
     87   // We destroy an iterator as soon as possible and mark it so
     88   const unsigned kArcIterDone = 0x08;      // Arc iterator done and destroyed
     89 
     90   vector<unsigned char> state_status;
     91   vector<AIterator *> arc_iterator;
     92 
     93   StateId nstates = start + 1;             // # of known states in general case
     94   bool expanded = false;
     95   if (fst.Properties(kExpanded, false)) {  // tests if expanded case, then
     96     nstates = CountStates(fst);            // uses ExpandedFst::NumStates().
     97     expanded = true;
     98   }
     99 
    100   state_status.resize(nstates, kWhiteState);
    101   arc_iterator.resize(nstates);
    102   StateIterator< Fst<Arc> > siter(fst);
    103 
    104   // Continues visit while true
    105   bool visit = true;
    106 
    107   // Iterates over trees in visit forest.
    108   for (StateId root = start; visit && root < nstates;) {
    109     visit = visitor->InitState(root, root);
    110     state_status[root] = kGreyState;
    111     queue->Enqueue(root);
    112     while (!queue->Empty()) {
    113       StateId s = queue->Head();
    114       if (s >= state_status.size()) {
    115         nstates = s + 1;
    116         state_status.resize(nstates, kWhiteState);
    117         arc_iterator.resize(nstates);
    118       }
    119       // Creates arc iterator if needed.
    120       if (arc_iterator[s] == 0 && !(state_status[s] & kArcIterDone) && visit)
    121         arc_iterator[s] = new AIterator(fst, s);
    122       // Deletes arc iterator if done.
    123       AIterator *aiter = arc_iterator[s];
    124       if ((aiter && aiter->Done()) || !visit) {
    125         delete aiter;
    126         arc_iterator[s] = 0;
    127         state_status[s] |= kArcIterDone;
    128       }
    129       // Dequeues state and marks black if done
    130       if (state_status[s] & kArcIterDone) {
    131         queue->Dequeue();
    132         visitor->FinishState(s);
    133         state_status[s] = kBlackState;
    134         continue;
    135       }
    136 
    137       const Arc &arc = aiter->Value();
    138       if (arc.nextstate >= state_status.size()) {
    139         nstates = arc.nextstate + 1;
    140         state_status.resize(nstates, kWhiteState);
    141         arc_iterator.resize(nstates);
    142       }
    143       // Visits respective arc types
    144       if (filter(arc)) {
    145         // Enqueues destination state and marks grey if white
    146         if (state_status[arc.nextstate] == kWhiteState) {
    147           visit = visitor->WhiteArc(s, arc);
    148           if (!visit) continue;
    149           visit = visitor->InitState(arc.nextstate, root);
    150           state_status[arc.nextstate] = kGreyState;
    151           queue->Enqueue(arc.nextstate);
    152         } else if (state_status[arc.nextstate] == kBlackState) {
    153           visit = visitor->BlackArc(s, arc);
    154         } else {
    155           visit = visitor->GreyArc(s, arc);
    156         }
    157       }
    158       aiter->Next();
    159       // Destroys an iterator ASAP for efficiency.
    160       if (aiter->Done()) {
    161         delete aiter;
    162         arc_iterator[s] = 0;
    163         state_status[s] |= kArcIterDone;
    164       }
    165     }
    166     // Finds next tree root
    167     for (root = root == start ? 0 : root + 1;
    168          root < nstates && state_status[root] != kWhiteState;
    169          ++root) {
    170     }
    171 
    172     // Check for a state beyond the largest known state
    173     if (!expanded && root == nstates) {
    174       for (; !siter.Done(); siter.Next()) {
    175         if (siter.Value() == nstates) {
    176           ++nstates;
    177           state_status.push_back(kWhiteState);
    178           arc_iterator.push_back(0);
    179           break;
    180         }
    181       }
    182     }
    183   }
    184   visitor->FinishVisit();
    185 }
    186 
    187 
    188 template <class Arc, class V, class Q>
    189 inline void Visit(const Fst<Arc> &fst, V *visitor, Q* queue) {
    190   Visit(fst, visitor, queue, AnyArcFilter<Arc>());
    191 }
    192 
    193 // Copies input FST to mutable FST following queue order.
    194 template <class A>
    195 class CopyVisitor {
    196  public:
    197   typedef A Arc;
    198   typedef typename A::StateId StateId;
    199 
    200   CopyVisitor(MutableFst<Arc> *ofst) : ifst_(0), ofst_(ofst) {}
    201 
    202   void InitVisit(const Fst<A> &ifst) {
    203     ifst_ = &ifst;
    204     ofst_->DeleteStates();
    205     ofst_->SetStart(ifst_->Start());
    206   }
    207 
    208   bool InitState(StateId s, StateId) {
    209     while (ofst_->NumStates() <= s)
    210       ofst_->AddState();
    211     return true;
    212   }
    213 
    214   bool WhiteArc(StateId s, const Arc &arc) {
    215     ofst_->AddArc(s, arc);
    216     return true;
    217   }
    218 
    219   bool GreyArc(StateId s, const Arc &arc) {
    220     ofst_->AddArc(s, arc);
    221     return true;
    222   }
    223 
    224   bool BlackArc(StateId s, const Arc &arc) {
    225     ofst_->AddArc(s, arc);
    226     return true;
    227   }
    228 
    229   void FinishState(StateId s) {
    230     ofst_->SetFinal(s, ifst_->Final(s));
    231   }
    232 
    233   void FinishVisit() {}
    234 
    235  private:
    236   const Fst<Arc> *ifst_;
    237   MutableFst<Arc> *ofst_;
    238 };
    239 
    240 
    241 // Visits input FST up to a state limit following queue order. If
    242 // 'access_only' is true, aborts on visiting first state not
    243 // accessible from the initial state.
    244 template <class A>
    245 class PartialVisitor {
    246  public:
    247   typedef A Arc;
    248   typedef typename A::StateId StateId;
    249 
    250   explicit PartialVisitor(StateId maxvisit, bool access_only = false)
    251       : maxvisit_(maxvisit),
    252         access_only_(access_only),
    253         start_(kNoStateId) {}
    254 
    255   void InitVisit(const Fst<A> &ifst) {
    256     nvisit_ = 0;
    257     start_ = ifst.Start();
    258   }
    259 
    260   bool InitState(StateId s, StateId root) {
    261     if (access_only_ && root != start_)
    262       return false;
    263     ++nvisit_;
    264     return nvisit_ <= maxvisit_;
    265   }
    266 
    267   bool WhiteArc(StateId s, const Arc &arc) { return true; }
    268   bool GreyArc(StateId s, const Arc &arc) { return true; }
    269   bool BlackArc(StateId s, const Arc &arc) { return true; }
    270   void FinishState(StateId s) {}
    271   void FinishVisit() {}
    272 
    273  private:
    274   StateId maxvisit_;
    275   bool access_only_;
    276   StateId nvisit_;
    277   StateId start_;
    278 
    279 };
    280 
    281 
    282 }  // namespace fst
    283 
    284 #endif  // FST_LIB_VISIT_H__
    285