Home | History | Annotate | Download | only in fst
      1 // lookahead-matcher.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // \file
     19 // Classes to add lookahead to FST matchers, useful e.g. for improving
     20 // composition efficiency with certain inputs.
     21 
     22 #ifndef FST_LIB_LOOKAHEAD_MATCHER_H__
     23 #define FST_LIB_LOOKAHEAD_MATCHER_H__
     24 
     25 #include <fst/add-on.h>
     26 #include <fst/const-fst.h>
     27 #include <fst/fst.h>
     28 #include <fst/label-reachable.h>
     29 #include <fst/matcher.h>
     30 
     31 
     32 DECLARE_string(save_relabel_ipairs);
     33 DECLARE_string(save_relabel_opairs);
     34 
     35 namespace fst {
     36 
// LOOKAHEAD MATCHERS - these have the interface of Matchers (see
// matcher.h) and these additional methods:
//
// template <class F>
// class LookAheadMatcher {
//  public:
//   typedef F FST;
//   typedef F::Arc Arc;
//   typedef typename Arc::StateId StateId;
//   typedef typename Arc::Label Label;
//   typedef typename Arc::Weight Weight;
//
//  // Required constructors.
//  LookAheadMatcher(const F &fst, MatchType match_type);
//   // If safe=true, the copy is thread-safe (except the lookahead Fst is
//   // preserved). See Fst<>::Copy() for further doc.
//  LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false);
//
//  Below are methods for looking ahead for a match to a label and
//  more generally, to a rational set. Each returns false if there is
//  definitely not a match and returns true if there possibly is a
//  match.
//
//  // LABEL LOOKAHEAD: Can 'label' be read from the current matcher state
//  // after possibly following epsilon transitions?
//  bool LookAheadLabel(Label label) const;
//
//  // RATIONAL LOOKAHEAD: The next methods allow looking ahead for an
//  // arbitrary rational set of strings, specified by an FST and a state
//  // from which to begin the matching. If the lookahead FST is a
//  // transducer, this looks on the side different from the matcher
//  // 'match_type' (cf. composition).
//
//  // Are there paths P from 's' in the lookahead FST that can be read from
//  // the current matcher state?
//  bool LookAheadFst(const Fst<Arc>& fst, StateId s);
//
//  // Gives an estimate of the combined weight of the paths P in the
//  // lookahead and matcher FSTs for the last call to LookAheadFst.
//  // A trivial implementation returns Weight::One(). Non-trivial
//  // implementations are useful for weight-pushing in composition.
//  Weight LookAheadWeight() const;
//
//  // Is there a single non-epsilon arc found in the lookahead FST
//  // that begins P (after possibly following any epsilons) in the last
//  // call to LookAheadFst? If so, return true and copy it to '*arc', o.w.
//  // return false. A trivial implementation returns false. Non-trivial
//  // implementations are useful for label-pushing in composition.
//  bool LookAheadPrefix(Arc *arc);
//
//  // Optionally pre-specifies the lookahead FST that will be passed
//  // to LookAheadFst() for possible precomputation. If copy is true,
//  // then 'fst' is a copy of the FST used in the previous call to
//  // this method (useful to avoid unnecessary updates).
//  void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false);
//
// };
     94 
     95 //
     96 // LOOK-AHEAD FLAGS (see also kMatcherFlags in matcher.h):
     97 //
     98 // Matcher is a lookahead matcher when 'match_type' is MATCH_INPUT.
     99 const uint32 kInputLookAheadMatcher =      0x00000001;
    100 
    101 // Matcher is a lookahead matcher when 'match_type' is MATCH_OUTPUT.
    102 const uint32 kOutputLookAheadMatcher =     0x00000002;
    103 
    104 // A non-trivial implementation of LookAheadWeight() method defined and
    105 // should be used?
    106 const uint32 kLookAheadWeight =            0x00000004;
    107 
    108 // A non-trivial implementation of LookAheadPrefix() method defined and
    109 // should be used?
    110 const uint32 kLookAheadPrefix =            0x00000008;
    111 
    112 // Look-ahead of matcher FST non-epsilon arcs?
    113 const uint32 kLookAheadNonEpsilons =       0x00000010;
    114 
    115 // Look-ahead of matcher FST epsilon arcs?
    116 const uint32 kLookAheadEpsilons =          0x00000020;
    117 
    118 // Ignore epsilon paths for the lookahead prefix? Note this gives
    119 // correct results in composition only with an appropriate composition
    120 // filter since it depends on the filter blocking the ignored paths.
    121 const uint32 kLookAheadNonEpsilonPrefix =  0x00000040;
    122 
    123 // For LabelLookAheadMatcher, save relabeling data to file
    124 const uint32 kLookAheadKeepRelabelData =  0x00000080;
    125 
    126 // Flags used for lookahead matchers.
    127 const uint32 kLookAheadFlags =            0x000000ff;
    128 
    129 // LookAhead Matcher interface, templated on the Arc definition; used
    130 // for lookahead matcher specializations that are returned by the
    131 // InitMatcher() Fst method.
    132 template <class A>
    133 class LookAheadMatcherBase : public MatcherBase<A> {
    134  public:
    135   typedef A Arc;
    136   typedef typename A::StateId StateId;
    137   typedef typename A::Label Label;
    138   typedef typename A::Weight Weight;
    139 
    140   LookAheadMatcherBase()
    141   : weight_(Weight::One()),
    142     prefix_arc_(kNoLabel, kNoLabel, Weight::One(), kNoStateId) {}
    143 
    144   virtual ~LookAheadMatcherBase() {}
    145 
    146   bool LookAheadLabel(Label label) const { return LookAheadLabel_(label); }
    147 
    148   bool LookAheadFst(const Fst<Arc> &fst, StateId s) {
    149     return LookAheadFst_(fst, s);
    150   }
    151 
    152   Weight LookAheadWeight() const { return weight_; }
    153 
    154   bool LookAheadPrefix(Arc *arc) const {
    155     if (prefix_arc_.nextstate != kNoStateId) {
    156       *arc = prefix_arc_;
    157       return true;
    158     } else {
    159       return false;
    160     }
    161   }
    162 
    163   virtual void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) = 0;
    164 
    165  protected:
    166   void SetLookAheadWeight(const Weight &w) { weight_ = w; }
    167 
    168   void SetLookAheadPrefix(const Arc &arc) { prefix_arc_ = arc; }
    169 
    170   void ClearLookAheadPrefix() { prefix_arc_.nextstate = kNoStateId; }
    171 
    172  private:
    173   virtual bool LookAheadLabel_(Label label) const = 0;
    174   virtual bool LookAheadFst_(const Fst<Arc> &fst,
    175                              StateId s) = 0;  // This must set l.a. weight and
    176                                               // prefix if non-trivial.
    177   Weight weight_;                             // Look-ahead weight
    178   Arc prefix_arc_;                            // Look-ahead prefix arc
    179 };
    180 
    181 
    182 // Don't really lookahead, just declare future looks good regardless.
    183 template <class M>
    184 class TrivialLookAheadMatcher
    185     : public LookAheadMatcherBase<typename M::FST::Arc> {
    186  public:
    187   typedef typename M::FST FST;
    188   typedef typename M::Arc Arc;
    189   typedef typename Arc::StateId StateId;
    190   typedef typename Arc::Label Label;
    191   typedef typename Arc::Weight Weight;
    192 
    193   TrivialLookAheadMatcher(const FST &fst, MatchType match_type)
    194       : matcher_(fst, match_type) {}
    195 
    196   TrivialLookAheadMatcher(const TrivialLookAheadMatcher<M> &lmatcher,
    197                           bool safe = false)
    198       : matcher_(lmatcher.matcher_, safe) {}
    199 
    200   // General matcher methods
    201   TrivialLookAheadMatcher<M> *Copy(bool safe = false) const {
    202     return new TrivialLookAheadMatcher<M>(*this, safe);
    203   }
    204 
    205   MatchType Type(bool test) const { return matcher_.Type(test); }
    206   void SetState(StateId s) { return matcher_.SetState(s); }
    207   bool Find(Label label) { return matcher_.Find(label); }
    208   bool Done() const { return matcher_.Done(); }
    209   const Arc& Value() const { return matcher_.Value(); }
    210   void Next() { matcher_.Next(); }
    211   virtual const FST &GetFst() const { return matcher_.GetFst(); }
    212   uint64 Properties(uint64 props) const { return matcher_.Properties(props); }
    213   uint32 Flags() const {
    214     return matcher_.Flags() | kInputLookAheadMatcher | kOutputLookAheadMatcher;
    215   }
    216 
    217   // Look-ahead methods.
    218   bool LookAheadLabel(Label label) const { return true;  }
    219   bool LookAheadFst(const Fst<Arc> &fst, StateId s) {return true; }
    220   Weight LookAheadWeight() const { return Weight::One(); }
    221   bool LookAheadPrefix(Arc *arc) const { return false; }
    222   void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {}
    223 
    224  private:
    225   // This allows base class virtual access to non-virtual derived-
    226   // class members of the same name. It makes the derived class more
    227   // efficient to use but unsafe to further derive.
    228   virtual void SetState_(StateId s) { SetState(s); }
    229   virtual bool Find_(Label label) { return Find(label); }
    230   virtual bool Done_() const { return Done(); }
    231   virtual const Arc& Value_() const { return Value(); }
    232   virtual void Next_() { Next(); }
    233 
    234   bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); }
    235 
    236   bool LookAheadFst_(const Fst<Arc> &fst, StateId s) {
    237     return LookAheadFst(fst, s);
    238   }
    239 
    240   Weight LookAheadWeight_() const { return LookAheadWeight(); }
    241   bool LookAheadPrefix_(Arc *arc) const { return LookAheadPrefix(arc); }
    242 
    243   M matcher_;
    244 };
    245 
    246 // Look-ahead of one transition. Template argument F accepts flags to
    247 // control behavior.
    248 template <class M, uint32 F = kLookAheadNonEpsilons | kLookAheadEpsilons |
    249           kLookAheadWeight | kLookAheadPrefix>
    250 class ArcLookAheadMatcher
    251     : public LookAheadMatcherBase<typename M::FST::Arc> {
    252  public:
    253   typedef typename M::FST FST;
    254   typedef typename M::Arc Arc;
    255   typedef typename Arc::StateId StateId;
    256   typedef typename Arc::Label Label;
    257   typedef typename Arc::Weight Weight;
    258   typedef NullAddOn MatcherData;
    259 
    260   using LookAheadMatcherBase<Arc>::LookAheadWeight;
    261   using LookAheadMatcherBase<Arc>::SetLookAheadPrefix;
    262   using LookAheadMatcherBase<Arc>::SetLookAheadWeight;
    263   using LookAheadMatcherBase<Arc>::ClearLookAheadPrefix;
    264 
    265   ArcLookAheadMatcher(const FST &fst, MatchType match_type,
    266                       MatcherData *data = 0)
    267       : matcher_(fst, match_type),
    268         fst_(matcher_.GetFst()),
    269         lfst_(0),
    270         s_(kNoStateId) {}
    271 
    272   ArcLookAheadMatcher(const ArcLookAheadMatcher<M, F> &lmatcher,
    273                       bool safe = false)
    274       : matcher_(lmatcher.matcher_, safe),
    275         fst_(matcher_.GetFst()),
    276         lfst_(lmatcher.lfst_),
    277         s_(kNoStateId) {}
    278 
    279   // General matcher methods
    280   ArcLookAheadMatcher<M, F> *Copy(bool safe = false) const {
    281     return new ArcLookAheadMatcher<M, F>(*this, safe);
    282   }
    283 
    284   MatchType Type(bool test) const { return matcher_.Type(test); }
    285 
    286   void SetState(StateId s) {
    287     s_ = s;
    288     matcher_.SetState(s);
    289   }
    290 
    291   bool Find(Label label) { return matcher_.Find(label); }
    292   bool Done() const { return matcher_.Done(); }
    293   const Arc& Value() const { return matcher_.Value(); }
    294   void Next() { matcher_.Next(); }
    295   const FST &GetFst() const { return fst_; }
    296   uint64 Properties(uint64 props) const { return matcher_.Properties(props); }
    297   uint32 Flags() const {
    298     return matcher_.Flags() | kInputLookAheadMatcher |
    299         kOutputLookAheadMatcher | F;
    300   }
    301 
    302   // Writable matcher methods
    303   MatcherData *GetData() const { return 0; }
    304 
    305   // Look-ahead methods.
    306   bool LookAheadLabel(Label label) const { return matcher_.Find(label); }
    307 
    308   // Checks if there is a matching (possibly super-final) transition
    309   // at (s_, s).
    310   bool LookAheadFst(const Fst<Arc> &fst, StateId s);
    311 
    312   void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {
    313     lfst_ = &fst;
    314   }
    315 
    316  private:
    317   // This allows base class virtual access to non-virtual derived-
    318   // class members of the same name. It makes the derived class more
    319   // efficient to use but unsafe to further derive.
    320   virtual void SetState_(StateId s) { SetState(s); }
    321   virtual bool Find_(Label label) { return Find(label); }
    322   virtual bool Done_() const { return Done(); }
    323   virtual const Arc& Value_() const { return Value(); }
    324   virtual void Next_() { Next(); }
    325 
    326   bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); }
    327   bool LookAheadFst_(const Fst<Arc> &fst, StateId s) {
    328     return LookAheadFst(fst, s);
    329   }
    330 
    331   mutable M matcher_;
    332   const FST &fst_;         // Matcher FST
    333   const Fst<Arc> *lfst_;   // Look-ahead FST
    334   StateId s_;              // Matcher state
    335 };
    336 
    337 template <class M, uint32 F>
    338 bool ArcLookAheadMatcher<M, F>::LookAheadFst(const Fst<Arc> &fst, StateId s) {
    339   if (&fst != lfst_)
    340     InitLookAheadFst(fst);
    341 
    342   bool ret = false;
    343   ssize_t nprefix = 0;
    344   if (F & kLookAheadWeight)
    345     SetLookAheadWeight(Weight::Zero());
    346   if (F & kLookAheadPrefix)
    347     ClearLookAheadPrefix();
    348   if (fst_.Final(s_) != Weight::Zero() &&
    349       lfst_->Final(s) != Weight::Zero()) {
    350     if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
    351       return true;
    352     ++nprefix;
    353     if (F & kLookAheadWeight)
    354       SetLookAheadWeight(Plus(LookAheadWeight(),
    355                               Times(fst_.Final(s_), lfst_->Final(s))));
    356     ret = true;
    357   }
    358   if (matcher_.Find(kNoLabel)) {
    359     if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
    360       return true;
    361     ++nprefix;
    362     if (F & kLookAheadWeight)
    363       for (; !matcher_.Done(); matcher_.Next())
    364         SetLookAheadWeight(Plus(LookAheadWeight(), matcher_.Value().weight));
    365     ret = true;
    366   }
    367   for (ArcIterator< Fst<Arc> > aiter(*lfst_, s);
    368        !aiter.Done();
    369        aiter.Next()) {
    370     const Arc &arc = aiter.Value();
    371     Label label = kNoLabel;
    372     switch (matcher_.Type(false)) {
    373       case MATCH_INPUT:
    374         label = arc.olabel;
    375         break;
    376       case MATCH_OUTPUT:
    377         label = arc.ilabel;
    378         break;
    379       default:
    380         FSTERROR() << "ArcLookAheadMatcher::LookAheadFst: bad match type";
    381         return true;
    382     }
    383     if (label == 0) {
    384       if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
    385         return true;
    386       if (!(F & kLookAheadNonEpsilonPrefix))
    387         ++nprefix;
    388       if (F & kLookAheadWeight)
    389         SetLookAheadWeight(Plus(LookAheadWeight(), arc.weight));
    390       ret = true;
    391     } else if (matcher_.Find(label)) {
    392       if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
    393         return true;
    394       for (; !matcher_.Done(); matcher_.Next()) {
    395         ++nprefix;
    396         if (F & kLookAheadWeight)
    397           SetLookAheadWeight(Plus(LookAheadWeight(),
    398                                   Times(arc.weight,
    399                                         matcher_.Value().weight)));
    400         if ((F & kLookAheadPrefix) && nprefix == 1)
    401           SetLookAheadPrefix(arc);
    402       }
    403       ret = true;
    404     }
    405   }
    406   if (F & kLookAheadPrefix) {
    407     if (nprefix == 1)
    408       SetLookAheadWeight(Weight::One());  // Avoids double counting.
    409     else
    410       ClearLookAheadPrefix();
    411   }
    412   return ret;
    413 }
    414 
    415 
    416 // Template argument F accepts flags to control behavior.
    417 // It must include precisely one of KInputLookAheadMatcher or
    418 // KOutputLookAheadMatcher.
    419 template <class M, uint32 F = kLookAheadEpsilons | kLookAheadWeight |
    420           kLookAheadPrefix | kLookAheadNonEpsilonPrefix |
    421           kLookAheadKeepRelabelData,
    422           class S = DefaultAccumulator<typename M::Arc> >
    423 class LabelLookAheadMatcher
    424     : public LookAheadMatcherBase<typename M::FST::Arc> {
    425  public:
    426   typedef typename M::FST FST;
    427   typedef typename M::Arc Arc;
    428   typedef typename Arc::StateId StateId;
    429   typedef typename Arc::Label Label;
    430   typedef typename Arc::Weight Weight;
    431   typedef LabelReachableData<Label> MatcherData;
    432 
    433   using LookAheadMatcherBase<Arc>::LookAheadWeight;
    434   using LookAheadMatcherBase<Arc>::SetLookAheadPrefix;
    435   using LookAheadMatcherBase<Arc>::SetLookAheadWeight;
    436   using LookAheadMatcherBase<Arc>::ClearLookAheadPrefix;
    437 
    438   LabelLookAheadMatcher(const FST &fst, MatchType match_type,
    439                         MatcherData *data = 0, S *s = 0)
    440       : matcher_(fst, match_type),
    441         lfst_(0),
    442         label_reachable_(0),
    443         s_(kNoStateId),
    444         error_(false) {
    445     if (!(F & (kInputLookAheadMatcher | kOutputLookAheadMatcher))) {
    446       FSTERROR() << "LabelLookaheadMatcher: bad matcher flags: " << F;
    447       error_ = true;
    448     }
    449     bool reach_input = match_type == MATCH_INPUT;
    450     if (data) {
    451       if (reach_input == data->ReachInput())
    452         label_reachable_ = new LabelReachable<Arc, S>(data, s);
    453     } else if ((reach_input && (F & kInputLookAheadMatcher)) ||
    454                (!reach_input && (F & kOutputLookAheadMatcher))) {
    455       label_reachable_ = new LabelReachable<Arc, S>(
    456           fst, reach_input, s, F & kLookAheadKeepRelabelData);
    457     }
    458   }
    459 
    460   LabelLookAheadMatcher(const LabelLookAheadMatcher<M, F, S> &lmatcher,
    461                         bool safe = false)
    462       : matcher_(lmatcher.matcher_, safe),
    463         lfst_(lmatcher.lfst_),
    464         label_reachable_(
    465             lmatcher.label_reachable_ ?
    466             new LabelReachable<Arc, S>(*lmatcher.label_reachable_) : 0),
    467         s_(kNoStateId),
    468         error_(lmatcher.error_) {}
    469 
    470   ~LabelLookAheadMatcher() {
    471     delete label_reachable_;
    472   }
    473 
    474   // General matcher methods
    475   LabelLookAheadMatcher<M, F, S> *Copy(bool safe = false) const {
    476     return new LabelLookAheadMatcher<M, F, S>(*this, safe);
    477   }
    478 
    479   MatchType Type(bool test) const { return matcher_.Type(test); }
    480 
    481   void SetState(StateId s) {
    482     if (s_ == s)
    483       return;
    484     s_ = s;
    485     match_set_state_ = false;
    486     reach_set_state_ = false;
    487   }
    488 
    489   bool Find(Label label) {
    490     if (!match_set_state_) {
    491       matcher_.SetState(s_);
    492       match_set_state_ = true;
    493     }
    494     return matcher_.Find(label);
    495   }
    496 
    497   bool Done() const { return matcher_.Done(); }
    498   const Arc& Value() const { return matcher_.Value(); }
    499   void Next() { matcher_.Next(); }
    500   const FST &GetFst() const { return matcher_.GetFst(); }
    501 
    502   uint64 Properties(uint64 inprops) const {
    503     uint64 outprops = matcher_.Properties(inprops);
    504     if (error_ || (label_reachable_ && label_reachable_->Error()))
    505       outprops |= kError;
    506     return outprops;
    507   }
    508 
    509   uint32 Flags() const {
    510     if (label_reachable_ && label_reachable_->GetData()->ReachInput())
    511       return matcher_.Flags() | F | kInputLookAheadMatcher;
    512     else if (label_reachable_ && !label_reachable_->GetData()->ReachInput())
    513       return matcher_.Flags() | F | kOutputLookAheadMatcher;
    514     else
    515       return matcher_.Flags();
    516   }
    517 
    518   // Writable matcher methods
    519   MatcherData *GetData() const {
    520     return label_reachable_ ? label_reachable_->GetData() : 0;
    521   };
    522 
    523   // Look-ahead methods.
    524   bool LookAheadLabel(Label label) const {
    525     if (label == 0)
    526       return true;
    527 
    528     if (label_reachable_) {
    529       if (!reach_set_state_) {
    530         label_reachable_->SetState(s_);
    531         reach_set_state_ = true;
    532       }
    533       return label_reachable_->Reach(label);
    534     } else {
    535       return true;
    536     }
    537   }
    538 
    539   // Checks if there is a matching (possibly super-final) transition
    540   // at (s_, s).
    541   template <class L>
    542   bool LookAheadFst(const L &fst, StateId s);
    543 
    544   void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {
    545     lfst_ = &fst;
    546     if (label_reachable_)
    547       label_reachable_->ReachInit(fst, copy);
    548   }
    549 
    550   template <class L>
    551   void InitLookAheadFst(const L& fst, bool copy = false) {
    552     lfst_ = static_cast<const Fst<Arc> *>(&fst);
    553     if (label_reachable_)
    554       label_reachable_->ReachInit(fst, copy);
    555   }
    556 
    557  private:
    558   // This allows base class virtual access to non-virtual derived-
    559   // class members of the same name. It makes the derived class more
    560   // efficient to use but unsafe to further derive.
    561   virtual void SetState_(StateId s) { SetState(s); }
    562   virtual bool Find_(Label label) { return Find(label); }
    563   virtual bool Done_() const { return Done(); }
    564   virtual const Arc& Value_() const { return Value(); }
    565   virtual void Next_() { Next(); }
    566 
    567   bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); }
    568   bool LookAheadFst_(const Fst<Arc> &fst, StateId s) {
    569     return LookAheadFst(fst, s);
    570   }
    571 
    572   mutable M matcher_;
    573   const Fst<Arc> *lfst_;                     // Look-ahead FST
    574   LabelReachable<Arc, S> *label_reachable_;  // Label reachability info
    575   StateId s_;                                // Matcher state
    576   bool match_set_state_;                     // matcher_.SetState called?
    577   mutable bool reach_set_state_;             // reachable_.SetState called?
    578   bool error_;
    579 };
    580 
    581 template <class M, uint32 F, class S>
    582 template <class L> inline
    583 bool LabelLookAheadMatcher<M, F, S>::LookAheadFst(const L &fst, StateId s) {
    584   if (static_cast<const Fst<Arc> *>(&fst) != lfst_)
    585     InitLookAheadFst(fst);
    586 
    587   SetLookAheadWeight(Weight::One());
    588   ClearLookAheadPrefix();
    589 
    590   if (!label_reachable_)
    591     return true;
    592 
    593   label_reachable_->SetState(s_, s);
    594   reach_set_state_ = true;
    595 
    596   bool compute_weight = F & kLookAheadWeight;
    597   bool compute_prefix = F & kLookAheadPrefix;
    598 
    599   bool reach_input = Type(false) == MATCH_OUTPUT;
    600   ArcIterator<L> aiter(fst, s);
    601   bool reach_arc = label_reachable_->Reach(&aiter, 0,
    602                                            internal::NumArcs(*lfst_, s),
    603                                            reach_input, compute_weight);
    604   if (reach_arc) {
    605     ssize_t begin = label_reachable_->ReachBegin();
    606     ssize_t end = label_reachable_->ReachEnd();
    607     if (compute_prefix && end - begin == 1) {
    608       aiter.Seek(begin);
    609       SetLookAheadPrefix(aiter.Value());
    610       compute_weight = false;
    611     } else if (compute_weight) {
    612       SetLookAheadWeight(label_reachable_->ReachWeight());
    613     }
    614   }
    615   Weight lfinal = internal::Final(*lfst_, s);
    616   bool reach_final = lfinal != Weight::Zero() &&
    617       label_reachable_->ReachFinal();
    618   if (reach_final && compute_weight)
    619     SetLookAheadWeight(reach_arc ?
    620                        Plus(LookAheadWeight(), lfinal) : lfinal);
    621 
    622   return reach_arc || reach_final;
    623 }
    624 
    625 
    626 // Label-lookahead relabeling class.
    627 template <class A>
    628 class LabelLookAheadRelabeler {
    629  public:
    630   typedef typename A::Label Label;
    631   typedef LabelReachableData<Label> MatcherData;
    632   typedef AddOnPair<MatcherData, MatcherData> D;
    633 
    634   // Relabels matcher Fst - initialization function object.
    635   template <typename I>
    636   LabelLookAheadRelabeler(I **impl);
    637 
    638   // Relabels arbitrary Fst. Class L should be a label-lookahead Fst.
    639   template <class L>
    640   static void Relabel(MutableFst<A> *fst, const L &mfst,
    641                       bool relabel_input) {
    642     typename L::Impl *impl = mfst.GetImpl();
    643     D *data = impl->GetAddOn();
    644     LabelReachable<A> reachable(data->First() ?
    645                                   data->First() : data->Second());
    646     reachable.Relabel(fst, relabel_input);
    647   }
    648 
    649   // Returns relabeling pairs (cf. relabel.h::Relabel()).
    650   // Class L should be a label-lookahead Fst.
    651   // If 'avoid_collisions' is true, extra pairs are added to
    652   // ensure no collisions when relabeling automata that have
    653   // labels unseen here.
    654   template <class L>
    655   static void RelabelPairs(const L &mfst, vector<pair<Label, Label> > *pairs,
    656                            bool avoid_collisions = false) {
    657     typename L::Impl *impl = mfst.GetImpl();
    658     D *data = impl->GetAddOn();
    659     LabelReachable<A> reachable(data->First() ?
    660                                   data->First() : data->Second());
    661     reachable.RelabelPairs(pairs, avoid_collisions);
    662   }
    663 };
    664 
    665 template <class A>
    666 template <typename I> inline
    667 LabelLookAheadRelabeler<A>::LabelLookAheadRelabeler(I **impl) {
    668   Fst<A> &fst = (*impl)->GetFst();
    669   D *data = (*impl)->GetAddOn();
    670   const string name = (*impl)->Type();
    671   bool is_mutable = fst.Properties(kMutable, false);
    672   MutableFst<A> *mfst = 0;
    673   if (is_mutable) {
    674     mfst = static_cast<MutableFst<A> *>(&fst);
    675   } else {
    676     mfst = new VectorFst<A>(fst);
    677     data->IncrRefCount();
    678     delete *impl;
    679   }
    680   if (data->First()) {  // reach_input
    681     LabelReachable<A> reachable(data->First());
    682     reachable.Relabel(mfst, true);
    683     if (!FLAGS_save_relabel_ipairs.empty()) {
    684       vector<pair<Label, Label> > pairs;
    685       reachable.RelabelPairs(&pairs, true);
    686       WriteLabelPairs(FLAGS_save_relabel_ipairs, pairs);
    687     }
    688   } else {
    689     LabelReachable<A> reachable(data->Second());
    690     reachable.Relabel(mfst, false);
    691     if (!FLAGS_save_relabel_opairs.empty()) {
    692       vector<pair<Label, Label> > pairs;
    693       reachable.RelabelPairs(&pairs, true);
    694       WriteLabelPairs(FLAGS_save_relabel_opairs, pairs);
    695     }
    696   }
    697   if (!is_mutable) {
    698     *impl = new I(*mfst, name);
    699     (*impl)->SetAddOn(data);
    700     delete mfst;
    701     data->DecrRefCount();
    702   }
    703 }
    704 
    705 
    706 // Generic lookahead matcher, templated on the FST definition
    707 // - a wrapper around pointer to specific one.
    708 template <class F>
    709 class LookAheadMatcher {
    710  public:
    711   typedef F FST;
    712   typedef typename F::Arc Arc;
    713   typedef typename Arc::StateId StateId;
    714   typedef typename Arc::Label Label;
    715   typedef typename Arc::Weight Weight;
    716   typedef LookAheadMatcherBase<Arc> LBase;
    717 
    718   LookAheadMatcher(const F &fst, MatchType match_type) {
    719     base_ = fst.InitMatcher(match_type);
    720     if (!base_)
    721       base_ = new SortedMatcher<F>(fst, match_type);
    722     lookahead_ = false;
    723   }
    724 
    725   LookAheadMatcher(const LookAheadMatcher<F> &matcher, bool safe = false) {
    726     base_ = matcher.base_->Copy(safe);
    727     lookahead_ = matcher.lookahead_;
    728   }
    729 
    730   ~LookAheadMatcher() { delete base_; }
    731 
    732   // General matcher methods
    733   LookAheadMatcher<F> *Copy(bool safe = false) const {
    734       return new LookAheadMatcher<F>(*this, safe);
    735   }
    736 
    737   MatchType Type(bool test) const { return base_->Type(test); }
    738   void SetState(StateId s) { base_->SetState(s); }
    739   bool Find(Label label) { return base_->Find(label); }
    740   bool Done() const { return base_->Done(); }
    741   const Arc& Value() const { return base_->Value(); }
    742   void Next() { base_->Next(); }
    743   const F &GetFst() const { return static_cast<const F &>(base_->GetFst()); }
    744 
    745   uint64 Properties(uint64 props) const { return base_->Properties(props); }
    746 
    747   uint32 Flags() const { return base_->Flags(); }
    748 
    749   // Look-ahead methods
    750   bool LookAheadLabel(Label label) const {
    751     if (LookAheadCheck()) {
    752       LBase *lbase = static_cast<LBase *>(base_);
    753       return lbase->LookAheadLabel(label);
    754     } else {
    755       return true;
    756     }
    757   }
    758 
    759   bool LookAheadFst(const Fst<Arc> &fst, StateId s) {
    760     if (LookAheadCheck()) {
    761       LBase *lbase = static_cast<LBase *>(base_);
    762       return lbase->LookAheadFst(fst, s);
    763     } else {
    764       return true;
    765     }
    766   }
    767 
    768   Weight LookAheadWeight() const {
    769     if (LookAheadCheck()) {
    770       LBase *lbase = static_cast<LBase *>(base_);
    771       return lbase->LookAheadWeight();
    772     } else {
    773       return Weight::One();
    774     }
    775   }
    776 
    777   bool LookAheadPrefix(Arc *arc) const {
    778     if (LookAheadCheck()) {
    779       LBase *lbase = static_cast<LBase *>(base_);
    780       return lbase->LookAheadPrefix(arc);
    781     } else {
    782       return false;
    783     }
    784   }
    785 
    786   void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {
    787     if (LookAheadCheck()) {
    788       LBase *lbase = static_cast<LBase *>(base_);
    789       lbase->InitLookAheadFst(fst, copy);
    790     }
    791   }
    792 
    793  private:
    794   bool LookAheadCheck() const {
    795     if (!lookahead_) {
    796       lookahead_ = base_->Flags() &
    797           (kInputLookAheadMatcher | kOutputLookAheadMatcher);
    798       if (!lookahead_) {
    799         FSTERROR() << "LookAheadMatcher: No look-ahead matcher defined";
    800       }
    801     }
    802     return lookahead_;
    803   }
    804 
    805   MatcherBase<Arc> *base_;
    806   mutable bool lookahead_;
    807 
    808   void operator=(const LookAheadMatcher<Arc> &);  // disallow
    809 };
    810 
    811 }  // namespace fst
    812 
    813 #endif  // FST_LIB_LOOKAHEAD_MATCHER_H__
    814