Home | History | Annotate | Download | only in fst
      1 // state-reachable.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // \file
     19 // Class to determine whether a given (final) state can be reached from some
     20 // other given state.
     21 
     22 #ifndef FST_LIB_STATE_REACHABLE_H__
     23 #define FST_LIB_STATE_REACHABLE_H__
     24 
     25 #include <vector>
     26 using std::vector;
     27 
     28 #include <fst/dfs-visit.h>
     29 #include <fst/fst.h>
     30 #include <fst/interval-set.h>
     31 
     32 
     33 namespace fst {
     34 
     35 // Computes the (final) states reachable from a given state in an FST.
     36 // After this visitor has been called, a final state f can be reached
     37 // from a state s iff (*isets)[s].Member(state2index[f]) is true, where
     38 // (*isets[s]) is a set of half-open inteval of final state indices
     39 // and state2index[f] maps from a final state to its index.
     40 //
     41 // If state2index is empty, it is filled-in with suitable indices.
     42 // If it is non-empty, those indices are used; in this case, the
     43 // final states must have out-degree 0.
     44 template <class A, typename I = typename A::StateId>
     45 class IntervalReachVisitor {
     46  public:
     47   typedef typename A::StateId StateId;
     48   typedef typename A::Label Label;
     49   typedef typename A::Weight Weight;
     50   typedef typename IntervalSet<I>::Interval Interval;
     51 
     52   IntervalReachVisitor(const Fst<A> &fst,
     53                        vector< IntervalSet<I> > *isets,
     54                        vector<I> *state2index)
     55       : fst_(fst),
     56         isets_(isets),
     57         state2index_(state2index),
     58         index_(state2index->empty() ? 1 : -1),
     59         error_(false) {
     60     isets_->clear();
     61   }
     62 
     63   void InitVisit(const Fst<A> &fst) { error_ = false; }
     64 
     65   bool InitState(StateId s, StateId r) {
     66     while (isets_->size() <= s)
     67       isets_->push_back(IntervalSet<Label>());
     68     while (state2index_->size() <= s)
     69       state2index_->push_back(-1);
     70 
     71     if (fst_.Final(s) != Weight::Zero()) {
     72       // Create tree interval
     73       vector<Interval> *intervals = (*isets_)[s].Intervals();
     74       if (index_ < 0) {  // Use state2index_ map to set index
     75         if (fst_.NumArcs(s) > 0) {
     76           FSTERROR() << "IntervalReachVisitor: state2index map must be empty "
     77                      << "for this FST";
     78           error_ = true;
     79           return false;
     80         }
     81         I index = (*state2index_)[s];
     82         if (index < 0) {
     83           FSTERROR() << "IntervalReachVisitor: state2index map incomplete";
     84           error_ = true;
     85           return false;
     86         }
     87         intervals->push_back(Interval(index, index + 1));
     88       } else {           // Use pre-order index
     89         intervals->push_back(Interval(index_, index_ + 1));
     90         (*state2index_)[s] = index_++;
     91       }
     92     }
     93     return true;
     94   }
     95 
     96   bool TreeArc(StateId s, const A &arc) {
     97     return true;
     98   }
     99 
    100   bool BackArc(StateId s, const A &arc) {
    101     FSTERROR() << "IntervalReachVisitor: cyclic input";
    102     error_ = true;
    103     return false;
    104   }
    105 
    106   bool ForwardOrCrossArc(StateId s, const A &arc) {
    107     // Non-tree interval
    108     (*isets_)[s].Union((*isets_)[arc.nextstate]);
    109     return true;
    110   }
    111 
    112   void FinishState(StateId s, StateId p, const A *arc) {
    113     if (index_ >= 0 && fst_.Final(s) != Weight::Zero()) {
    114       vector<Interval> *intervals = (*isets_)[s].Intervals();
    115       (*intervals)[0].end = index_;      // Update tree interval end
    116     }
    117     (*isets_)[s].Normalize();
    118     if (p != kNoStateId)
    119       (*isets_)[p].Union((*isets_)[s]);  // Propagate intervals to parent
    120   }
    121 
    122   void FinishVisit() {}
    123 
    124   bool Error() const { return error_; }
    125 
    126  private:
    127   const Fst<A> &fst_;
    128   vector< IntervalSet<I> > *isets_;
    129   vector<I> *state2index_;
    130   I index_;
    131   bool error_;
    132 };
    133 
    134 
    135 // Tests reachability of final states from a given state. To test for
    136 // reachability from a state s, first do SetState(s). Then a final
    137 // state f can be reached from state s of FST iff Reach(f) is true.
    138 template <class A, typename I = typename A::StateId>
    139 class StateReachable {
    140  public:
    141   typedef A Arc;
    142   typedef I Index;
    143   typedef typename A::StateId StateId;
    144   typedef typename A::Label Label;
    145   typedef typename A::Weight Weight;
    146   typedef typename IntervalSet<I>::Interval Interval;
    147 
    148   StateReachable(const Fst<A> &fst)
    149       : error_(false) {
    150     IntervalReachVisitor<Arc> reach_visitor(fst, &isets_, &state2index_);
    151     DfsVisit(fst, &reach_visitor);
    152     if (reach_visitor.Error()) error_ = true;
    153   }
    154 
    155   StateReachable(const StateReachable<A> &reachable) {
    156     FSTERROR() << "Copy constructor for state reachable class "
    157                << "not yet implemented.";
    158     error_ = true;
    159   }
    160 
    161   // Set current state.
    162   void SetState(StateId s) { s_ = s; }
    163 
    164   // Can reach this label from current state?
    165   bool Reach(StateId s) {
    166     if (s >= state2index_.size())
    167       return false;
    168 
    169     I i =  state2index_[s];
    170     if (i < 0) {
    171       FSTERROR() << "StateReachable: state non-final: " << s;
    172       error_ = true;
    173       return false;
    174     }
    175     return isets_[s_].Member(i);
    176   }
    177 
    178   // Access to the state-to-index mapping. Unassigned states have index -1.
    179   vector<I> &State2Index() { return state2index_; }
    180 
    181   // Access to the interval sets. These specify the reachability
    182   // to the final states as intervals of the final state indices.
    183   const vector< IntervalSet<I> > &IntervalSets() { return isets_; }
    184 
    185   bool Error() const { return error_; }
    186 
    187  private:
    188   StateId s_;                                 // Current state
    189   vector< IntervalSet<I> > isets_;            // Interval sets per state
    190   vector<I> state2index_;                     // Finds index for a final state
    191   bool error_;
    192 
    193   void operator=(const StateReachable<A> &);  // Disallow
    194 };
    195 
    196 }  // namespace fst
    197 
    198 #endif  // FST_LIB_STATE_REACHABLE_H__
    199