Home | History | Annotate | Download | only in fst
      1 // compose-filter.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // \file
     19 // Classes for filtering the composition matches, e.g. for correct epsilon
     20 // handling.
     21 
     22 #ifndef FST_LIB_COMPOSE_FILTER_H__
     23 #define FST_LIB_COMPOSE_FILTER_H__
     24 
#include <climits>  // For CHAR_BIT (used by PairFilterState::Hash)

#include <fst/fst.h>
#include <fst/fst-decl.h>  // For optional argument declarations
#include <fst/matcher.h>
     28 
     29 
     30 namespace fst {
     31 
     32 
     33 // COMPOSITION FILTER STATE - this represents the state of
     34 // the composition filter. It has the form:
     35 //
     36 // class FilterState {
     37 //  public:
     38 //   // Required constructors
     39 //   FilterState();
     40 //   FilterState(const FilterState &f);
     41 //   // An invalid filter state.
     42 //   static const FilterState NoState();
     43 //   // Maps state to integer for hashing.
     44 //   size_t Hash() const;
     45 //   // Equality of filter states.
     46 //   bool operator==(const FilterState &f) const;
     47 //   // Inequality of filter states.
     48 //   bool operator!=(const FilterState &f) const;
     49 //   // Assignment to filter states.
     50 //   FilterState& operator=(const FilterState& f);
     51 // };
     52 
     53 
     54 // Filter state that is a signed integral type.
     55 template <typename T>
     56 class IntegerFilterState {
     57  public:
     58   IntegerFilterState() : state_(kNoStateId) {}
     59   explicit IntegerFilterState(T s) : state_(s) {}
     60 
     61   static const IntegerFilterState NoState() { return IntegerFilterState(); }
     62 
     63   size_t Hash() const { return static_cast<size_t>(state_); }
     64 
     65   bool operator==(const IntegerFilterState &f) const {
     66     return state_ == f.state_;
     67   }
     68 
     69   bool operator!=(const IntegerFilterState &f) const {
     70     return state_ != f.state_;
     71   }
     72 
     73   T GetState() const { return state_; }
     74 
     75   void SetState(T state) { state_ = state; }
     76 
     77 private:
     78   T state_;
     79 };
     80 
     81 typedef IntegerFilterState<signed char> CharFilterState;
     82 typedef IntegerFilterState<short> ShortFilterState;
     83 typedef IntegerFilterState<int> IntFilterState;
     84 
     85 
     86 // Filter state that is a weight (class).
     87 template <class W>
     88 class WeightFilterState {
     89  public:
     90   WeightFilterState() : weight_(W::Zero()) {}
     91   explicit WeightFilterState(W w) : weight_(w) {}
     92 
     93   static const WeightFilterState NoState() { return WeightFilterState(); }
     94 
     95   size_t Hash() const { return weight_.Hash(); }
     96 
     97   bool operator==(const WeightFilterState &f) const {
     98     return weight_ == f.weight_;
     99   }
    100 
    101   bool operator!=(const WeightFilterState &f) const {
    102     return weight_ != f.weight_;
    103   }
    104 
    105   W GetWeight() const { return weight_; }
    106 
    107   void SetWeight(W w) { weight_ = w; }
    108 
    109 private:
    110   W weight_;
    111 };
    112 
    113 
    114 // Filter state that is the combination of two filter states.
    115 template <class F1, class F2>
    116 class PairFilterState {
    117  public:
    118   PairFilterState() : f1_(F1::NoState()), f2_(F2::NoState()) {}
    119 
    120   PairFilterState(const F1 &f1, const F2 &f2) : f1_(f1), f2_(f2) {}
    121 
    122   static const PairFilterState NoState() { return PairFilterState(); }
    123 
    124   size_t Hash() const {
    125     size_t h1 = f1_.Hash();
    126     size_t h2 = f2_.Hash();
    127     const int lshift = 5;
    128     const int rshift = CHAR_BIT * sizeof(size_t) - 5;
    129     return h1 << lshift ^ h1 >> rshift ^ h2;
    130   }
    131 
    132   bool operator==(const PairFilterState &f) const {
    133     return f1_ == f.f1_ && f2_ == f.f2_;
    134   }
    135 
    136   bool operator!=(const PairFilterState &f) const {
    137     return f1_ != f.f1_ || f2_ != f.f2_;
    138   }
    139 
    140   const F1 &GetState1() const { return f1_; }
    141   const F2 &GetState2() const { return f2_; }
    142 
    143   void SetState(const F1 &f1, const F2 &f2) {
    144     f1_ = f1;
    145     f2_ = f2;
    146   }
    147 
    148 private:
    149   F1 f1_;
    150   F2 f2_;
    151 };
    152 
    153 
    154 // COMPOSITION FILTERS - these determine which matches are allowed to
    155 // proceed. The filter's state is represented by the type
    156 // ComposeFilter::FilterState. The basic filters handle correct
    157 // epsilon matching.  Their interface is:
    158 //
    159 // template <class M1, class M2>
    160 // class ComposeFilter {
    161 //  public:
//   typedef typename M1::FST FST1;
//   typedef typename M2::FST FST2;
    164 //   typedef typename FST1::Arc Arc;
    165 //   typedef ... FilterState;
    166 //   typedef ... Matcher1;
    167 //   typedef ... Matcher2;
    168 //
    169 //   // Required constructors.
//   ComposeFilter(const FST1 &fst1, const FST2 &fst2,
//                 M1 *matcher1 = 0, M2 *matcher2 = 0);
    172 //   // If safe=true, the copy is thread-safe. See Fst<>::Copy()
    173 //   // for further doc.
//   ComposeFilter(const ComposeFilter<M1, M2> &filter,
//                 bool safe = false);
    176 //   // Return start state of filter.
    177 //   FilterState Start() const;
    178 //   // Specifies current composition state.
    179 //   void SetState(StateId s1, StateId s2, const FilterState &f);
    180 //
    181 //   // Apply filter at current composition state to these transitions.
//   // If an arc label to be matched is kNoLabel, then that side
//   // does not consume a symbol. Returns the new filter state or,
//   // if disallowed, FilterState::NoState(). The filter is permitted to
//   // modify its inputs, e.g. for optimizations.
//   FilterState FilterArc(Arc *arc1, Arc *arc2) const;
//
    188 //   // Apply filter at current composition state to these final weights
    189 //   // (cf. superfinal transitions). The filter may modify its inputs,
    190 //   // e.g. for optimizations.
    191 //   void FilterFinal(Weight *final1, Weight *final2) const;
    192 //
//   // Return the respective matchers. Ownership stays with the filter. These
    194 //   // methods allow the filter to access and possibly modify
    195 //   // the composition matchers (useful e.g. with lookahead).
    196 //   Matcher1 *GetMatcher1();
    197 //   Matcher2 *GetMatcher2();
    198 //
    199 //   // This specifies how the filter affects the composition result
    200 //   // properties. It takes as argument the properties that would
//   // apply with a trivial composition filter.
    202 //   uint64 Properties(uint64 props) const;
    203 // };
    204 
    205 // This filter requires epsilons on FST1 to be read before epsilons on FST2.
    206 template <class M1, class M2>
    207 class SequenceComposeFilter {
    208  public:
    209   typedef typename M1::FST FST1;
    210   typedef typename M2::FST FST2;
    211   typedef typename FST1::Arc Arc;
    212   typedef CharFilterState FilterState;
    213   typedef M1 Matcher1;
    214   typedef M2 Matcher2;
    215 
    216   typedef typename Arc::StateId StateId;
    217   typedef typename Arc::Label Label;
    218   typedef typename Arc::Weight Weight;
    219 
    220   SequenceComposeFilter(const FST1 &fst1, const FST2 &fst2,
    221                         M1 *matcher1 = 0, M2 *matcher2 = 0)
    222       : matcher1_(matcher1 ? matcher1 : new M1(fst1, MATCH_OUTPUT)),
    223         matcher2_(matcher2 ? matcher2 : new M2(fst2, MATCH_INPUT)),
    224         fst1_(matcher1_->GetFst()),
    225         s1_(kNoStateId),
    226         s2_(kNoStateId),
    227         f_(kNoStateId) {}
    228 
    229   SequenceComposeFilter(const SequenceComposeFilter<M1, M2> &filter,
    230                         bool safe = false)
    231       : matcher1_(filter.matcher1_->Copy(safe)),
    232         matcher2_(filter.matcher2_->Copy(safe)),
    233         fst1_(matcher1_->GetFst()),
    234         s1_(kNoStateId),
    235         s2_(kNoStateId),
    236         f_(kNoStateId) {}
    237 
    238   ~SequenceComposeFilter() {
    239     delete matcher1_;
    240     delete matcher2_;
    241   }
    242 
    243   FilterState Start() const { return FilterState(0); }
    244 
    245   void SetState(StateId s1, StateId s2, const FilterState &f) {
    246     if (s1_ == s1 && s2_ == s2 && f == f_)
    247       return;
    248     s1_ = s1;
    249     s2_ = s2;
    250     f_ = f;
    251     size_t na1 = internal::NumArcs(fst1_, s1);
    252     size_t ne1 = internal::NumOutputEpsilons(fst1_, s1);
    253     bool fin1 = internal::Final(fst1_, s1) != Weight::Zero();
    254     alleps1_ = na1 == ne1 && !fin1;
    255     noeps1_ = ne1 == 0;
    256   }
    257 
    258   FilterState FilterArc(Arc *arc1, Arc *arc2) const {
    259     if (arc1->olabel == kNoLabel)
    260       return alleps1_ ? FilterState::NoState() :
    261         noeps1_ ? FilterState(0) : FilterState(1);
    262     else if (arc2->ilabel == kNoLabel)
    263       return f_ != FilterState(0) ? FilterState::NoState() : FilterState(0);
    264     else
    265       return arc1->olabel == 0 ? FilterState::NoState() : FilterState(0);
    266   }
    267 
    268   void FilterFinal(Weight *, Weight *) const {}
    269 
    270   // Return resp matchers. Ownership stays with filter.
    271   Matcher1 *GetMatcher1() { return matcher1_; }
    272   Matcher2 *GetMatcher2() { return matcher2_; }
    273 
    274   uint64 Properties(uint64 props) const { return props; }
    275 
    276  private:
    277   Matcher1 *matcher1_;
    278   Matcher2 *matcher2_;
    279   const FST1 &fst1_;
    280   StateId s1_;     // Current fst1_ state;
    281   StateId s2_;     // Current fst2_ state;
    282   FilterState f_;  // Current filter state
    283   bool alleps1_;   // Only epsilons (and non-final) leaving s1_?
    284   bool noeps1_;    // No epsilons leaving s1_?
    285 
    286   void operator=(const SequenceComposeFilter<M1, M2> &);  // disallow
    287 };
    288 
    289 
    290 // This filter requires epsilons on FST2 to be read before epsilons on FST1.
    291 template <class M1, class M2>
    292 class AltSequenceComposeFilter {
    293  public:
    294   typedef typename M1::FST FST1;
    295   typedef typename M2::FST FST2;
    296   typedef typename FST1::Arc Arc;
    297   typedef CharFilterState FilterState;
    298   typedef M1 Matcher1;
    299   typedef M2 Matcher2;
    300 
    301   typedef typename Arc::StateId StateId;
    302   typedef typename Arc::Label Label;
    303   typedef typename Arc::Weight Weight;
    304 
    305   AltSequenceComposeFilter(const FST1 &fst1, const FST2 &fst2,
    306                         M1 *matcher1 = 0, M2 *matcher2 = 0)
    307       : matcher1_(matcher1 ? matcher1 : new M1(fst1, MATCH_OUTPUT)),
    308         matcher2_(matcher2 ? matcher2 : new M2(fst2, MATCH_INPUT)),
    309         fst2_(matcher2_->GetFst()),
    310         s1_(kNoStateId),
    311         s2_(kNoStateId),
    312         f_(kNoStateId) {}
    313 
    314   AltSequenceComposeFilter(const AltSequenceComposeFilter<M1, M2> &filter,
    315                            bool safe = false)
    316       : matcher1_(filter.matcher1_->Copy(safe)),
    317         matcher2_(filter.matcher2_->Copy(safe)),
    318         fst2_(matcher2_->GetFst()),
    319         s1_(kNoStateId),
    320         s2_(kNoStateId),
    321         f_(kNoStateId) {}
    322 
    323   ~AltSequenceComposeFilter() {
    324     delete matcher1_;
    325     delete matcher2_;
    326   }
    327 
    328   FilterState Start() const { return FilterState(0); }
    329 
    330   void SetState(StateId s1, StateId s2, const FilterState &f) {
    331     if (s1_ == s1 && s2_ == s2 && f == f_)
    332       return;
    333     s1_ = s1;
    334     s2_ = s2;
    335     f_ = f;
    336     size_t na2 = internal::NumArcs(fst2_, s2);
    337     size_t ne2 = internal::NumInputEpsilons(fst2_, s2);
    338     bool fin2 = internal::Final(fst2_, s2) != Weight::Zero();
    339     alleps2_ = na2 == ne2 && !fin2;
    340     noeps2_ = ne2 == 0;
    341   }
    342 
    343   FilterState FilterArc(Arc *arc1, Arc *arc2) const {
    344     if (arc2->ilabel == kNoLabel)
    345       return alleps2_ ? FilterState::NoState() :
    346         noeps2_ ? FilterState(0) : FilterState(1);
    347     else if (arc1->olabel == kNoLabel)
    348       return f_ == FilterState(1) ? FilterState::NoState() : FilterState(0);
    349     else
    350       return arc1->olabel == 0 ? FilterState::NoState() : FilterState(0);
    351   }
    352 
    353   void FilterFinal(Weight *, Weight *) const {}
    354 
    355   // Return resp matchers. Ownership stays with filter.
    356   Matcher1 *GetMatcher1() { return matcher1_; }
    357   Matcher2 *GetMatcher2() { return matcher2_; }
    358 
    359   uint64 Properties(uint64 props) const { return props; }
    360 
    361  private:
    362   Matcher1 *matcher1_;
    363   Matcher2 *matcher2_;
    364   const FST2 &fst2_;
    365   StateId s1_;     // Current fst1_ state;
    366   StateId s2_;     // Current fst2_ state;
    367   FilterState f_;  // Current filter state
    368   bool alleps2_;   // Only epsilons (and non-final) leaving s2_?
    369   bool noeps2_;    // No epsilons leaving s2_?
    370 
    371 void operator=(const AltSequenceComposeFilter<M1, M2> &);  // disallow
    372 };
    373 
    374 
    375 // This filter requires epsilons on FST1 to be matched with epsilons on FST2
    376 // whenever possible.
    377 template <class M1, class M2>
    378 class MatchComposeFilter {
    379  public:
    380   typedef typename M1::FST FST1;
    381   typedef typename M2::FST FST2;
    382   typedef typename FST1::Arc Arc;
    383   typedef CharFilterState FilterState;
    384   typedef M1 Matcher1;
    385   typedef M2 Matcher2;
    386 
    387   typedef typename Arc::StateId StateId;
    388   typedef typename Arc::Label Label;
    389   typedef typename Arc::Weight Weight;
    390 
    391   MatchComposeFilter(const FST1 &fst1, const FST2 &fst2,
    392                      M1 *matcher1 = 0, M2 *matcher2 = 0)
    393       : matcher1_(matcher1 ? matcher1 : new M1(fst1, MATCH_OUTPUT)),
    394         matcher2_(matcher2 ? matcher2 : new M2(fst2, MATCH_INPUT)),
    395         fst1_(matcher1_->GetFst()),
    396         fst2_(matcher2_->GetFst()),
    397         s1_(kNoStateId),
    398         s2_(kNoStateId),
    399         f_(kNoStateId) {}
    400 
    401   MatchComposeFilter(const MatchComposeFilter<M1, M2> &filter,
    402                      bool safe = false)
    403       : matcher1_(filter.matcher1_->Copy(safe)),
    404         matcher2_(filter.matcher2_->Copy(safe)),
    405         fst1_(matcher1_->GetFst()),
    406         fst2_(matcher2_->GetFst()),
    407         s1_(kNoStateId),
    408         s2_(kNoStateId),
    409         f_(kNoStateId) {}
    410 
    411   ~MatchComposeFilter() {
    412     delete matcher1_;
    413     delete matcher2_;
    414   }
    415 
    416   FilterState Start() const { return FilterState(0); }
    417 
    418   void SetState(StateId s1, StateId s2, const FilterState &f) {
    419     if (s1_ == s1 && s2_ == s2 && f == f_)
    420       return;
    421     s1_ = s1;
    422     s2_ = s2;
    423     f_ = f;
    424     size_t na1 = internal::NumArcs(fst1_, s1);
    425     size_t ne1 = internal::NumOutputEpsilons(fst1_, s1);
    426     bool f1 = internal::Final(fst1_, s1) != Weight::Zero();
    427     alleps1_ = na1 == ne1 && !f1;
    428     noeps1_ = ne1 == 0;
    429     size_t na2 = internal::NumArcs(fst2_, s2);
    430     size_t ne2 = internal::NumInputEpsilons(fst2_, s2);
    431     bool f2 = internal::Final(fst2_, s2) != Weight::Zero();
    432     alleps2_ = na2 == ne2 && !f2;
    433     noeps2_ = ne2 == 0;
    434   }
    435 
    436   FilterState FilterArc(Arc *arc1, Arc *arc2) const {
    437     if (arc2->ilabel == kNoLabel)  // Epsilon on Fst1
    438       return f_ == FilterState(0) ?
    439           (noeps2_ ? FilterState(0) :
    440            (alleps2_ ? FilterState::NoState(): FilterState(1))) :
    441           (f_ == FilterState(1) ? FilterState(1) : FilterState::NoState());
    442     else if (arc1->olabel == kNoLabel)  // Epsilon on Fst2
    443       return f_ == FilterState(0) ?
    444           (noeps1_ ? FilterState(0) :
    445            (alleps1_ ? FilterState::NoState() : FilterState(2))) :
    446           (f_ == FilterState(2) ? FilterState(2) : FilterState::NoState());
    447     else if (arc1->olabel == 0)  // Epsilon on both
    448       return f_ == FilterState(0) ? FilterState(0) : FilterState::NoState();
    449     else  // Both are non-epsilons
    450       return FilterState(0);
    451   }
    452 
    453   void FilterFinal(Weight *, Weight *) const {}
    454 
    455   // Return resp matchers. Ownership stays with filter.
    456   Matcher1 *GetMatcher1() { return matcher1_; }
    457   Matcher2 *GetMatcher2() { return matcher2_; }
    458 
    459   uint64 Properties(uint64 props) const { return props; }
    460 
    461  private:
    462   Matcher1 *matcher1_;
    463   Matcher2 *matcher2_;
    464   const FST1 &fst1_;
    465   const FST2 &fst2_;
    466   StateId s1_;              // Current fst1_ state;
    467   StateId s2_;              // Current fst2_ state;
    468   FilterState f_;           // Current filter state ID
    469   bool alleps1_, alleps2_;  // Only epsilons (and non-final) leaving s1, s2?
    470   bool noeps1_, noeps2_;    // No epsilons leaving s1, s2?
    471 
    472   void operator=(const MatchComposeFilter<M1, M2> &);  // disallow
    473 };
    474 
    475 
    476 // This filter works with the MultiEpsMatcher to determine if
    477 // 'multi-epsilons' are preserved in the composition output
    478 // (rather than rewritten as 0) and ensures correct properties.
    479 template <class F>
    480 class MultiEpsFilter {
    481  public:
    482   typedef typename F::FST1 FST1;
    483   typedef typename F::FST2 FST2;
    484   typedef typename F::Arc Arc;
    485   typedef typename F::Matcher1 Matcher1;
    486   typedef typename F::Matcher2 Matcher2;
    487   typedef typename F::FilterState FilterState;
    488   typedef MultiEpsFilter<F> Filter;
    489 
    490   typedef typename Arc::StateId StateId;
    491   typedef typename Arc::Label Label;
    492   typedef typename Arc::Weight Weight;
    493 
    494   MultiEpsFilter(const FST1 &fst1, const FST2 &fst2,
    495                  Matcher1 *matcher1 = 0,  Matcher2 *matcher2 = 0,
    496                  bool keep_multi_eps = false)
    497       : filter_(fst1, fst2, matcher1, matcher2),
    498         keep_multi_eps_(keep_multi_eps) {}
    499 
    500   MultiEpsFilter(const Filter &filter, bool safe = false)
    501       : filter_(filter.filter_, safe),
    502         keep_multi_eps_(filter.keep_multi_eps_) {}
    503 
    504   FilterState Start() const { return filter_.Start(); }
    505 
    506   void SetState(StateId s1, StateId s2, const FilterState &f) {
    507     return filter_.SetState(s1, s2, f);
    508   }
    509 
    510   FilterState FilterArc(Arc *arc1, Arc *arc2) const {
    511     FilterState f = filter_.FilterArc(arc1, arc2);
    512     if (keep_multi_eps_) {
    513       if (arc1->olabel == kNoLabel)
    514         arc1->ilabel = arc2->ilabel;
    515       if (arc2->ilabel == kNoLabel)
    516         arc2->olabel = arc1->olabel;
    517     }
    518     return f;
    519   }
    520 
    521   void FilterFinal(Weight *w1, Weight *w2) const {
    522     return filter_.FilterFinal(w1, w2);
    523   }
    524 
    525   // Return resp matchers. Ownership stays with filter.
    526   Matcher1 *GetMatcher1() { return filter_.GetMatcher1(); }
    527   Matcher2 *GetMatcher2() { return filter_.GetMatcher2(); }
    528 
    529   uint64 Properties(uint64 iprops) const {
    530     uint64 oprops = filter_.Properties(iprops);
    531     return oprops & kILabelInvariantProperties & kOLabelInvariantProperties;
    532   }
    533 
    534  private:
    535   F filter_;
    536   bool keep_multi_eps_;
    537 };
    538 
    539 }  // namespace fst
    540 
    541 
    542 #endif  // FST_LIB_COMPOSE_FILTER_H__
    543