Home | History | Annotate | Download | only in fst
      1 // lookahead-matcher.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // \file
     19 // Classes to add lookahead to FST matchers, useful e.g. for improving
     20 // composition efficiency with certain inputs.
     21 
     22 #ifndef FST_LIB_LOOKAHEAD_MATCHER_H__
     23 #define FST_LIB_LOOKAHEAD_MATCHER_H__
     24 
     25 #include <fst/add-on.h>
     26 #include <fst/const-fst.h>
     27 #include <fst/fst.h>
     28 #include <fst/label-reachable.h>
     29 #include <fst/matcher.h>
     30 
     31 
     32 DECLARE_string(save_relabel_ipairs);
     33 DECLARE_string(save_relabel_opairs);
     34 
     35 namespace fst {
     36 
     37 // LOOKAHEAD MATCHERS - these have the interface of Matchers (see
     38 // matcher.h) and these additional methods:
     39 //
     40 // template <class F>
     41 // class LookAheadMatcher {
     42 //  public:
     43 //   typedef F FST;
     44 //   typedef F::Arc Arc;
     45 //   typedef typename Arc::StateId StateId;
     46 //   typedef typename Arc::Label Label;
     47 //   typedef typename Arc::Weight Weight;
     48 //
     49 //  // Required constructors.
     50 //  LookAheadMatcher(const F &fst, MatchType match_type);
     51 //   // If safe=true, the copy is thread-safe (except the lookahead Fst is
     52 //   // preserved). See Fst<>::Copy() for further doc.
     53 //  LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false);
     54 //
     55 //  Below are methods for looking ahead for a match to a label and
     56 //  more generally, to a rational set. Each returns false if there is
     57 //  definitely not a match and returns true if there possibly is a
     58 //  match.
     59 
     60 //  // LABEL LOOKAHEAD: Can 'label' be read from the current matcher state
     61 //  // after possibly following epsilon transitions?
     62 //  bool LookAheadLabel(Label label) const;
     63 //
     64 //  // RATIONAL LOOKAHEAD: The next methods allow looking ahead for an
     65 //  // arbitrary rational set of strings, specified by an FST and a state
     66 //  // from which to begin the matching. If the lookahead FST is a
     67 //  // transducer, this looks on the side different from the matcher
     68 //  // 'match_type' (cf. composition).
     69 //
     70 //  // Are there paths P from 's' in the lookahead FST that can be read from
     71 //  // the cur. matcher state?
     72 //  bool LookAheadFst(const Fst<Arc>& fst, StateId s);
     73 //
     74 //  // Gives an estimate of the combined weight of the paths P in the
     75 //  // lookahead and matcher FSTs for the last call to LookAheadFst.
     76 //  // A trivial implementation returns Weight::One(). Non-trivial
     77 //  // implementations are useful for weight-pushing in composition.
     78 //  Weight LookAheadWeight() const;
     79 //
     80 //  // Is there a single non-epsilon arc found in the lookahead FST
     81 //  // that begins P (after possibly following any epsilons) in the last
     82 //  // call to LookAheadFst? If so, return true and copy it to '*arc', o.w.
     83 //  // return false. A trivial implementation returns false. Non-trivial
     84 //  // implementations are useful for label-pushing in composition.
     85 //  bool LookAheadPrefix(Arc *arc);
     86 //
     87 //  // Optionally pre-specifies the lookahead FST that will be passed
     88 //  // to LookAheadFst() for possible precomputation. If copy is true,
     89 //  // then 'fst' is a copy of the FST used in the previous call to
     90 //  // this method (useful to avoid unnecessary updates).
     91 //  void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false);
     92 //
     93 // };
     94 
     95 //
     96 // LOOK-AHEAD FLAGS (see also kMatcherFlags in matcher.h):
     97 //
     98 // Matcher is a lookahead matcher when 'match_type' is MATCH_INPUT.
     99 const uint32 kInputLookAheadMatcher =     0x00000010;
    100 
    101 // Matcher is a lookahead matcher when 'match_type' is MATCH_OUTPUT.
    102 const uint32 kOutputLookAheadMatcher =    0x00000020;
    103 
    104 // A non-trivial implementation of LookAheadWeight() method defined and
    105 // should be used?
    106 const uint32 kLookAheadWeight =           0x00000040;
    107 
    108 // A non-trivial implementation of LookAheadPrefix() method defined and
    109 // should be used?
    110 const uint32 kLookAheadPrefix =           0x00000080;
    111 
    112 // Look-ahead of matcher FST non-epsilon arcs?
    113 const uint32 kLookAheadNonEpsilons =      0x00000100;
    114 
    115 // Look-ahead of matcher FST epsilon arcs?
    116 const uint32 kLookAheadEpsilons =         0x00000200;
    117 
    118 // Ignore epsilon paths for the lookahead prefix? Note this gives
    119 // correct results in composition only with an appropriate composition
    120 // filter since it depends on the filter blocking the ignored paths.
    121 const uint32 kLookAheadNonEpsilonPrefix = 0x00000400;
    122 
    123 // For LabelLookAheadMatcher, save relabeling data to file
    124 const uint32 kLookAheadKeepRelabelData =  0x00000800;
    125 
    126 // Flags used for lookahead matchers.
    127 const uint32 kLookAheadFlags =            0x00000ff0;
    128 
    129 // LookAhead Matcher interface, templated on the Arc definition; used
    130 // for lookahead matcher specializations that are returned by the
    131 // InitMatcher() Fst method.
    132 template <class A>
    133 class LookAheadMatcherBase : public MatcherBase<A> {
    134  public:
    135   typedef A Arc;
    136   typedef typename A::StateId StateId;
    137   typedef typename A::Label Label;
    138   typedef typename A::Weight Weight;
    139 
    140   LookAheadMatcherBase()
    141   : weight_(Weight::One()),
    142     prefix_arc_(kNoLabel, kNoLabel, Weight::One(), kNoStateId) {}
    143 
    144   virtual ~LookAheadMatcherBase() {}
    145 
    146   bool LookAheadLabel(Label label) const { return LookAheadLabel_(label); }
    147 
    148   bool LookAheadFst(const Fst<Arc> &fst, StateId s) {
    149     return LookAheadFst_(fst, s);
    150   }
    151 
    152   Weight LookAheadWeight() const { return weight_; }
    153 
    154   bool LookAheadPrefix(Arc *arc) const {
    155     if (prefix_arc_.nextstate != kNoStateId) {
    156       *arc = prefix_arc_;
    157       return true;
    158     } else {
    159       return false;
    160     }
    161   }
    162 
    163   virtual void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) = 0;
    164 
    165  protected:
    166   void SetLookAheadWeight(const Weight &w) { weight_ = w; }
    167 
    168   void SetLookAheadPrefix(const Arc &arc) { prefix_arc_ = arc; }
    169 
    170   void ClearLookAheadPrefix() { prefix_arc_.nextstate = kNoStateId; }
    171 
    172  private:
    173   virtual bool LookAheadLabel_(Label label) const = 0;
    174   virtual bool LookAheadFst_(const Fst<Arc> &fst,
    175                              StateId s) = 0;  // This must set l.a. weight and
    176                                               // prefix if non-trivial.
    177   Weight weight_;                             // Look-ahead weight
    178   Arc prefix_arc_;                            // Look-ahead prefix arc
    179 };
    180 
    181 
    182 // Don't really lookahead, just declare future looks good regardless.
    183 template <class M>
    184 class TrivialLookAheadMatcher
    185     : public LookAheadMatcherBase<typename M::FST::Arc> {
    186  public:
    187   typedef typename M::FST FST;
    188   typedef typename M::Arc Arc;
    189   typedef typename Arc::StateId StateId;
    190   typedef typename Arc::Label Label;
    191   typedef typename Arc::Weight Weight;
    192 
    193   TrivialLookAheadMatcher(const FST &fst, MatchType match_type)
    194       : matcher_(fst, match_type) {}
    195 
    196   TrivialLookAheadMatcher(const TrivialLookAheadMatcher<M> &lmatcher,
    197                           bool safe = false)
    198       : matcher_(lmatcher.matcher_, safe) {}
    199 
    200   // General matcher methods
    201   TrivialLookAheadMatcher<M> *Copy(bool safe = false) const {
    202     return new TrivialLookAheadMatcher<M>(*this, safe);
    203   }
    204 
    205   MatchType Type(bool test) const { return matcher_.Type(test); }
    206   void SetState(StateId s) { return matcher_.SetState(s); }
    207   bool Find(Label label) { return matcher_.Find(label); }
    208   bool Done() const { return matcher_.Done(); }
    209   const Arc& Value() const { return matcher_.Value(); }
    210   void Next() { matcher_.Next(); }
    211   virtual const FST &GetFst() const { return matcher_.GetFst(); }
    212   uint64 Properties(uint64 props) const { return matcher_.Properties(props); }
    213   uint32 Flags() const {
    214     return matcher_.Flags() | kInputLookAheadMatcher | kOutputLookAheadMatcher;
    215   }
    216 
    217   // Look-ahead methods.
    218   bool LookAheadLabel(Label label) const { return true;  }
    219   bool LookAheadFst(const Fst<Arc> &fst, StateId s) {return true; }
    220   Weight LookAheadWeight() const { return Weight::One(); }
    221   bool LookAheadPrefix(Arc *arc) const { return false; }
    222   void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {}
    223 
    224  private:
    225   // This allows base class virtual access to non-virtual derived-
    226   // class members of the same name. It makes the derived class more
    227   // efficient to use but unsafe to further derive.
    228   virtual void SetState_(StateId s) { SetState(s); }
    229   virtual bool Find_(Label label) { return Find(label); }
    230   virtual bool Done_() const { return Done(); }
    231   virtual const Arc& Value_() const { return Value(); }
    232   virtual void Next_() { Next(); }
    233 
    234   bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); }
    235 
    236   bool LookAheadFst_(const Fst<Arc> &fst, StateId s) {
    237     return LookAheadFst(fst, s);
    238   }
    239 
    240   Weight LookAheadWeight_() const { return LookAheadWeight(); }
    241   bool LookAheadPrefix_(Arc *arc) const { return LookAheadPrefix(arc); }
    242 
    243   M matcher_;
    244 };
    245 
    246 // Look-ahead of one transition. Template argument F accepts flags to
    247 // control behavior.
    248 template <class M, uint32 F = kLookAheadNonEpsilons | kLookAheadEpsilons |
    249           kLookAheadWeight | kLookAheadPrefix>
    250 class ArcLookAheadMatcher
    251     : public LookAheadMatcherBase<typename M::FST::Arc> {
    252  public:
    253   typedef typename M::FST FST;
    254   typedef typename M::Arc Arc;
    255   typedef typename Arc::StateId StateId;
    256   typedef typename Arc::Label Label;
    257   typedef typename Arc::Weight Weight;
    258   typedef NullAddOn MatcherData;
    259 
    260   using LookAheadMatcherBase<Arc>::LookAheadWeight;
    261   using LookAheadMatcherBase<Arc>::SetLookAheadPrefix;
    262   using LookAheadMatcherBase<Arc>::SetLookAheadWeight;
    263   using LookAheadMatcherBase<Arc>::ClearLookAheadPrefix;
    264 
    265   ArcLookAheadMatcher(const FST &fst, MatchType match_type,
    266                       MatcherData *data = 0)
    267       : matcher_(fst, match_type),
    268         fst_(matcher_.GetFst()),
    269         lfst_(0),
    270         s_(kNoStateId) {}
    271 
    272   ArcLookAheadMatcher(const ArcLookAheadMatcher<M, F> &lmatcher,
    273                       bool safe = false)
    274       : matcher_(lmatcher.matcher_, safe),
    275         fst_(matcher_.GetFst()),
    276         lfst_(lmatcher.lfst_),
    277         s_(kNoStateId) {}
    278 
    279   // General matcher methods
    280   ArcLookAheadMatcher<M, F> *Copy(bool safe = false) const {
    281     return new ArcLookAheadMatcher<M, F>(*this, safe);
    282   }
    283 
    284   MatchType Type(bool test) const { return matcher_.Type(test); }
    285 
    286   void SetState(StateId s) {
    287     s_ = s;
    288     matcher_.SetState(s);
    289   }
    290 
    291   bool Find(Label label) { return matcher_.Find(label); }
    292   bool Done() const { return matcher_.Done(); }
    293   const Arc& Value() const { return matcher_.Value(); }
    294   void Next() { matcher_.Next(); }
    295   const FST &GetFst() const { return fst_; }
    296   uint64 Properties(uint64 props) const { return matcher_.Properties(props); }
    297   uint32 Flags() const {
    298     return matcher_.Flags() | kInputLookAheadMatcher |
    299         kOutputLookAheadMatcher | F;
    300   }
    301 
    302   // Writable matcher methods
    303   MatcherData *GetData() const { return 0; }
    304 
    305   // Look-ahead methods.
    306   bool LookAheadLabel(Label label) const { return matcher_.Find(label); }
    307 
    308   // Checks if there is a matching (possibly super-final) transition
    309   // at (s_, s).
    310   bool LookAheadFst(const Fst<Arc> &fst, StateId s);
    311 
    312   void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {
    313     lfst_ = &fst;
    314   }
    315 
    316  private:
    317   // This allows base class virtual access to non-virtual derived-
    318   // class members of the same name. It makes the derived class more
    319   // efficient to use but unsafe to further derive.
    320   virtual void SetState_(StateId s) { SetState(s); }
    321   virtual bool Find_(Label label) { return Find(label); }
    322   virtual bool Done_() const { return Done(); }
    323   virtual const Arc& Value_() const { return Value(); }
    324   virtual void Next_() { Next(); }
    325 
    326   bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); }
    327   bool LookAheadFst_(const Fst<Arc> &fst, StateId s) {
    328     return LookAheadFst(fst, s);
    329   }
    330 
    331   mutable M matcher_;
    332   const FST &fst_;         // Matcher FST
    333   const Fst<Arc> *lfst_;   // Look-ahead FST
    334   StateId s_;              // Matcher state
    335 };
    336 
    337 template <class M, uint32 F>
    338 bool ArcLookAheadMatcher<M, F>::LookAheadFst(const Fst<Arc> &fst, StateId s) {
    339   if (&fst != lfst_)
    340     InitLookAheadFst(fst);
    341 
    342   bool ret = false;
    343   ssize_t nprefix = 0;
    344   if (F & kLookAheadWeight)
    345     SetLookAheadWeight(Weight::Zero());
    346   if (F & kLookAheadPrefix)
    347     ClearLookAheadPrefix();
    348   if (fst_.Final(s_) != Weight::Zero() &&
    349       lfst_->Final(s) != Weight::Zero()) {
    350     if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
    351       return true;
    352     ++nprefix;
    353     if (F & kLookAheadWeight)
    354       SetLookAheadWeight(Plus(LookAheadWeight(),
    355                               Times(fst_.Final(s_), lfst_->Final(s))));
    356     ret = true;
    357   }
    358   if (matcher_.Find(kNoLabel)) {
    359     if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
    360       return true;
    361     ++nprefix;
    362     if (F & kLookAheadWeight)
    363       for (; !matcher_.Done(); matcher_.Next())
    364         SetLookAheadWeight(Plus(LookAheadWeight(), matcher_.Value().weight));
    365     ret = true;
    366   }
    367   for (ArcIterator< Fst<Arc> > aiter(*lfst_, s);
    368        !aiter.Done();
    369        aiter.Next()) {
    370     const Arc &arc = aiter.Value();
    371     Label label = kNoLabel;
    372     switch (matcher_.Type(false)) {
    373       case MATCH_INPUT:
    374         label = arc.olabel;
    375         break;
    376       case MATCH_OUTPUT:
    377         label = arc.ilabel;
    378         break;
    379       default:
    380         FSTERROR() << "ArcLookAheadMatcher::LookAheadFst: bad match type";
    381         return true;
    382     }
    383     if (label == 0) {
    384       if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
    385         return true;
    386       if (!(F & kLookAheadNonEpsilonPrefix))
    387         ++nprefix;
    388       if (F & kLookAheadWeight)
    389         SetLookAheadWeight(Plus(LookAheadWeight(), arc.weight));
    390       ret = true;
    391     } else if (matcher_.Find(label)) {
    392       if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
    393         return true;
    394       for (; !matcher_.Done(); matcher_.Next()) {
    395         ++nprefix;
    396         if (F & kLookAheadWeight)
    397           SetLookAheadWeight(Plus(LookAheadWeight(),
    398                                   Times(arc.weight,
    399                                         matcher_.Value().weight)));
    400         if ((F & kLookAheadPrefix) && nprefix == 1)
    401           SetLookAheadPrefix(arc);
    402       }
    403       ret = true;
    404     }
    405   }
    406   if (F & kLookAheadPrefix) {
    407     if (nprefix == 1)
    408       SetLookAheadWeight(Weight::One());  // Avoids double counting.
    409     else
    410       ClearLookAheadPrefix();
    411   }
    412   return ret;
    413 }
    414 
    415 
    416 // Template argument F accepts flags to control behavior.
    417 // It must include precisely one of KInputLookAheadMatcher or
    418 // KOutputLookAheadMatcher.
    419 template <class M, uint32 F = kLookAheadEpsilons | kLookAheadWeight |
    420           kLookAheadPrefix | kLookAheadNonEpsilonPrefix |
    421           kLookAheadKeepRelabelData,
    422           class S = DefaultAccumulator<typename M::Arc> >
    423 class LabelLookAheadMatcher
    424     : public LookAheadMatcherBase<typename M::FST::Arc> {
    425  public:
    426   typedef typename M::FST FST;
    427   typedef typename M::Arc Arc;
    428   typedef typename Arc::StateId StateId;
    429   typedef typename Arc::Label Label;
    430   typedef typename Arc::Weight Weight;
    431   typedef LabelReachableData<Label> MatcherData;
    432 
    433   using LookAheadMatcherBase<Arc>::LookAheadWeight;
    434   using LookAheadMatcherBase<Arc>::SetLookAheadPrefix;
    435   using LookAheadMatcherBase<Arc>::SetLookAheadWeight;
    436   using LookAheadMatcherBase<Arc>::ClearLookAheadPrefix;
    437 
    438   LabelLookAheadMatcher(const FST &fst, MatchType match_type,
    439                         MatcherData *data = 0, S *s = 0)
    440       : matcher_(fst, match_type),
    441         lfst_(0),
    442         label_reachable_(0),
    443         s_(kNoStateId),
    444         error_(false) {
    445     if (!(F & (kInputLookAheadMatcher | kOutputLookAheadMatcher))) {
    446       FSTERROR() << "LabelLookaheadMatcher: bad matcher flags: " << F;
    447       error_ = true;
    448     }
    449     bool reach_input = match_type == MATCH_INPUT;
    450     if (data) {
    451       if (reach_input == data->ReachInput())
    452         label_reachable_ = new LabelReachable<Arc, S>(data, s);
    453     } else if ((reach_input && (F & kInputLookAheadMatcher)) ||
    454                (!reach_input && (F & kOutputLookAheadMatcher))) {
    455       label_reachable_ = new LabelReachable<Arc, S>(
    456           fst, reach_input, s, F & kLookAheadKeepRelabelData);
    457     }
    458   }
    459 
    460   LabelLookAheadMatcher(const LabelLookAheadMatcher<M, F, S> &lmatcher,
    461                         bool safe = false)
    462       : matcher_(lmatcher.matcher_, safe),
    463         lfst_(lmatcher.lfst_),
    464         label_reachable_(
    465             lmatcher.label_reachable_ ?
    466             new LabelReachable<Arc, S>(*lmatcher.label_reachable_) : 0),
    467         s_(kNoStateId),
    468         error_(lmatcher.error_) {}
    469 
    470   ~LabelLookAheadMatcher() {
    471     delete label_reachable_;
    472   }
    473 
    474   // General matcher methods
    475   LabelLookAheadMatcher<M, F, S> *Copy(bool safe = false) const {
    476     return new LabelLookAheadMatcher<M, F, S>(*this, safe);
    477   }
    478 
    479   MatchType Type(bool test) const { return matcher_.Type(test); }
    480 
    481   void SetState(StateId s) {
    482     if (s_ == s)
    483       return;
    484     s_ = s;
    485     match_set_state_ = false;
    486     reach_set_state_ = false;
    487   }
    488 
    489   bool Find(Label label) {
    490     if (!match_set_state_) {
    491       matcher_.SetState(s_);
    492       match_set_state_ = true;
    493     }
    494     return matcher_.Find(label);
    495   }
    496 
    497   bool Done() const { return matcher_.Done(); }
    498   const Arc& Value() const { return matcher_.Value(); }
    499   void Next() { matcher_.Next(); }
    500   const FST &GetFst() const { return matcher_.GetFst(); }
    501 
    502   uint64 Properties(uint64 inprops) const {
    503     uint64 outprops = matcher_.Properties(inprops);
    504     if (error_ || (label_reachable_ && label_reachable_->Error()))
    505       outprops |= kError;
    506     return outprops;
    507   }
    508 
    509   uint32 Flags() const {
    510     if (label_reachable_ && label_reachable_->GetData()->ReachInput())
    511       return matcher_.Flags() | F | kInputLookAheadMatcher;
    512     else if (label_reachable_ && !label_reachable_->GetData()->ReachInput())
    513       return matcher_.Flags() | F | kOutputLookAheadMatcher;
    514     else
    515       return matcher_.Flags();
    516   }
    517 
    518   // Writable matcher methods
    519   MatcherData *GetData() const {
    520     return label_reachable_ ? label_reachable_->GetData() : 0;
    521   };
    522 
    523   // Look-ahead methods.
    524   bool LookAheadLabel(Label label) const {
    525     if (label == 0)
    526       return true;
    527 
    528     if (label_reachable_) {
    529       if (!reach_set_state_) {
    530         label_reachable_->SetState(s_);
    531         reach_set_state_ = true;
    532       }
    533       return label_reachable_->Reach(label);
    534     } else {
    535       return true;
    536     }
    537   }
    538 
    539   // Checks if there is a matching (possibly super-final) transition
    540   // at (s_, s).
    541   template <class L>
    542   bool LookAheadFst(const L &fst, StateId s);
    543 
    544   void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {
    545     lfst_ = &fst;
    546     if (label_reachable_)
    547       label_reachable_->ReachInit(fst, copy);
    548   }
    549 
    550   template <class L>
    551   void InitLookAheadFst(const L& fst, bool copy = false) {
    552     lfst_ = static_cast<const Fst<Arc> *>(&fst);
    553     if (label_reachable_)
    554       label_reachable_->ReachInit(fst, copy);
    555   }
    556 
    557  private:
    558   // This allows base class virtual access to non-virtual derived-
    559   // class members of the same name. It makes the derived class more
    560   // efficient to use but unsafe to further derive.
    561   virtual void SetState_(StateId s) { SetState(s); }
    562   virtual bool Find_(Label label) { return Find(label); }
    563   virtual bool Done_() const { return Done(); }
    564   virtual const Arc& Value_() const { return Value(); }
    565   virtual void Next_() { Next(); }
    566 
    567   bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); }
    568   bool LookAheadFst_(const Fst<Arc> &fst, StateId s) {
    569     return LookAheadFst(fst, s);
    570   }
    571 
    572   mutable M matcher_;
    573   const Fst<Arc> *lfst_;                     // Look-ahead FST
    574   LabelReachable<Arc, S> *label_reachable_;  // Label reachability info
    575   StateId s_;                                // Matcher state
    576   bool match_set_state_;                     // matcher_.SetState called?
    577   mutable bool reach_set_state_;             // reachable_.SetState called?
    578   bool error_;
    579 };
    580 
    581 template <class M, uint32 F, class S>
    582 template <class L> inline
    583 bool LabelLookAheadMatcher<M, F, S>::LookAheadFst(const L &fst, StateId s) {
    584   if (static_cast<const Fst<Arc> *>(&fst) != lfst_)
    585     InitLookAheadFst(fst);
    586 
    587   SetLookAheadWeight(Weight::One());
    588   ClearLookAheadPrefix();
    589 
    590   if (!label_reachable_)
    591     return true;
    592 
    593   label_reachable_->SetState(s_, s);
    594   reach_set_state_ = true;
    595 
    596   bool compute_weight = F & kLookAheadWeight;
    597   bool compute_prefix = F & kLookAheadPrefix;
    598 
    599   bool reach_input = Type(false) == MATCH_OUTPUT;
    600   ArcIterator<L> aiter(fst, s);
    601   bool reach_arc = label_reachable_->Reach(&aiter, 0,
    602                                            internal::NumArcs(*lfst_, s),
    603                                            reach_input, compute_weight);
    604   Weight lfinal = internal::Final(*lfst_, s);
    605   bool reach_final = lfinal != Weight::Zero() && label_reachable_->ReachFinal();
    606   if (reach_arc) {
    607     ssize_t begin = label_reachable_->ReachBegin();
    608     ssize_t end = label_reachable_->ReachEnd();
    609     if (compute_prefix && end - begin == 1 && !reach_final) {
    610       aiter.Seek(begin);
    611       SetLookAheadPrefix(aiter.Value());
    612       compute_weight = false;
    613     } else if (compute_weight) {
    614       SetLookAheadWeight(label_reachable_->ReachWeight());
    615     }
    616   }
    617   if (reach_final && compute_weight)
    618     SetLookAheadWeight(reach_arc ?
    619                        Plus(LookAheadWeight(), lfinal) : lfinal);
    620 
    621   return reach_arc || reach_final;
    622 }
    623 
    624 
    625 // Label-lookahead relabeling class.
    626 template <class A>
    627 class LabelLookAheadRelabeler {
    628  public:
    629   typedef typename A::Label Label;
    630   typedef LabelReachableData<Label> MatcherData;
    631   typedef AddOnPair<MatcherData, MatcherData> D;
    632 
    633   // Relabels matcher Fst - initialization function object.
    634   template <typename I>
    635   LabelLookAheadRelabeler(I **impl);
    636 
    637   // Relabels arbitrary Fst. Class L should be a label-lookahead Fst.
    638   template <class L>
    639   static void Relabel(MutableFst<A> *fst, const L &mfst,
    640                       bool relabel_input) {
    641     typename L::Impl *impl = mfst.GetImpl();
    642     D *data = impl->GetAddOn();
    643     LabelReachable<A> reachable(data->First() ?
    644                                   data->First() : data->Second());
    645     reachable.Relabel(fst, relabel_input);
    646   }
    647 
    648   // Returns relabeling pairs (cf. relabel.h::Relabel()).
    649   // Class L should be a label-lookahead Fst.
    650   // If 'avoid_collisions' is true, extra pairs are added to
    651   // ensure no collisions when relabeling automata that have
    652   // labels unseen here.
    653   template <class L>
    654   static void RelabelPairs(const L &mfst, vector<pair<Label, Label> > *pairs,
    655                            bool avoid_collisions = false) {
    656     typename L::Impl *impl = mfst.GetImpl();
    657     D *data = impl->GetAddOn();
    658     LabelReachable<A> reachable(data->First() ?
    659                                   data->First() : data->Second());
    660     reachable.RelabelPairs(pairs, avoid_collisions);
    661   }
    662 };
    663 
    664 template <class A>
    665 template <typename I> inline
    666 LabelLookAheadRelabeler<A>::LabelLookAheadRelabeler(I **impl) {
    667   Fst<A> &fst = (*impl)->GetFst();
    668   D *data = (*impl)->GetAddOn();
    669   const string name = (*impl)->Type();
    670   bool is_mutable = fst.Properties(kMutable, false);
    671   MutableFst<A> *mfst = 0;
    672   if (is_mutable) {
    673     mfst = static_cast<MutableFst<A> *>(&fst);
    674   } else {
    675     mfst = new VectorFst<A>(fst);
    676     data->IncrRefCount();
    677     delete *impl;
    678   }
    679   if (data->First()) {  // reach_input
    680     LabelReachable<A> reachable(data->First());
    681     reachable.Relabel(mfst, true);
    682     if (!FLAGS_save_relabel_ipairs.empty()) {
    683       vector<pair<Label, Label> > pairs;
    684       reachable.RelabelPairs(&pairs, true);
    685       WriteLabelPairs(FLAGS_save_relabel_ipairs, pairs);
    686     }
    687   } else {
    688     LabelReachable<A> reachable(data->Second());
    689     reachable.Relabel(mfst, false);
    690     if (!FLAGS_save_relabel_opairs.empty()) {
    691       vector<pair<Label, Label> > pairs;
    692       reachable.RelabelPairs(&pairs, true);
    693       WriteLabelPairs(FLAGS_save_relabel_opairs, pairs);
    694     }
    695   }
    696   if (!is_mutable) {
    697     *impl = new I(*mfst, name);
    698     (*impl)->SetAddOn(data);
    699     delete mfst;
    700     data->DecrRefCount();
    701   }
    702 }
    703 
    704 
    705 // Generic lookahead matcher, templated on the FST definition
    706 // - a wrapper around pointer to specific one.
    707 template <class F>
    708 class LookAheadMatcher {
    709  public:
    710   typedef F FST;
    711   typedef typename F::Arc Arc;
    712   typedef typename Arc::StateId StateId;
    713   typedef typename Arc::Label Label;
    714   typedef typename Arc::Weight Weight;
    715   typedef LookAheadMatcherBase<Arc> LBase;
    716 
    717   LookAheadMatcher(const F &fst, MatchType match_type) {
    718     base_ = fst.InitMatcher(match_type);
    719     if (!base_)
    720       base_ = new SortedMatcher<F>(fst, match_type);
    721     lookahead_ = false;
    722   }
    723 
    724   LookAheadMatcher(const LookAheadMatcher<F> &matcher, bool safe = false) {
    725     base_ = matcher.base_->Copy(safe);
    726     lookahead_ = matcher.lookahead_;
    727   }
    728 
    729   ~LookAheadMatcher() { delete base_; }
    730 
    731   // General matcher methods
    732   LookAheadMatcher<F> *Copy(bool safe = false) const {
    733       return new LookAheadMatcher<F>(*this, safe);
    734   }
    735 
    736   MatchType Type(bool test) const { return base_->Type(test); }
    737   void SetState(StateId s) { base_->SetState(s); }
    738   bool Find(Label label) { return base_->Find(label); }
    739   bool Done() const { return base_->Done(); }
    740   const Arc& Value() const { return base_->Value(); }
    741   void Next() { base_->Next(); }
    742   const F &GetFst() const { return static_cast<const F &>(base_->GetFst()); }
    743 
    744   uint64 Properties(uint64 props) const { return base_->Properties(props); }
    745 
    746   uint32 Flags() const { return base_->Flags(); }
    747 
    748   // Look-ahead methods
    749   bool LookAheadLabel(Label label) const {
    750     if (LookAheadCheck()) {
    751       LBase *lbase = static_cast<LBase *>(base_);
    752       return lbase->LookAheadLabel(label);
    753     } else {
    754       return true;
    755     }
    756   }
    757 
    758   bool LookAheadFst(const Fst<Arc> &fst, StateId s) {
    759     if (LookAheadCheck()) {
    760       LBase *lbase = static_cast<LBase *>(base_);
    761       return lbase->LookAheadFst(fst, s);
    762     } else {
    763       return true;
    764     }
    765   }
    766 
    767   Weight LookAheadWeight() const {
    768     if (LookAheadCheck()) {
    769       LBase *lbase = static_cast<LBase *>(base_);
    770       return lbase->LookAheadWeight();
    771     } else {
    772       return Weight::One();
    773     }
    774   }
    775 
    776   bool LookAheadPrefix(Arc *arc) const {
    777     if (LookAheadCheck()) {
    778       LBase *lbase = static_cast<LBase *>(base_);
    779       return lbase->LookAheadPrefix(arc);
    780     } else {
    781       return false;
    782     }
    783   }
    784 
    785   void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {
    786     if (LookAheadCheck()) {
    787       LBase *lbase = static_cast<LBase *>(base_);
    788       lbase->InitLookAheadFst(fst, copy);
    789     }
    790   }
    791 
    792  private:
    793   bool LookAheadCheck() const {
    794     if (!lookahead_) {
    795       lookahead_ = base_->Flags() &
    796           (kInputLookAheadMatcher | kOutputLookAheadMatcher);
    797       if (!lookahead_) {
    798         FSTERROR() << "LookAheadMatcher: No look-ahead matcher defined";
    799       }
    800     }
    801     return lookahead_;
    802   }
    803 
    804   MatcherBase<Arc> *base_;
    805   mutable bool lookahead_;
    806 
    807   void operator=(const LookAheadMatcher<Arc> &);  // disallow
    808 };
    809 
    810 }  // namespace fst
    811 
    812 #endif  // FST_LIB_LOOKAHEAD_MATCHER_H__
    813