Home | History | Annotate | Download | only in lib
      1 // factor-weight.h
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Author: allauzen (at) cs.nyu.edu (Cyril Allauzen)
     16 //
     17 // \file
     18 // Classes to factor weights in an FST.
     19 
     20 #ifndef FST_LIB_FACTOR_WEIGHT_H__
     21 #define FST_LIB_FACTOR_WEIGHT_H__
     22 
     23 #include <algorithm>
     24 
     25 #include <ext/hash_map>
     26 using __gnu_cxx::hash_map;
     27 #include <ext/slist>
     28 using __gnu_cxx::slist;
     29 
     30 #include "fst/lib/cache.h"
     31 #include "fst/lib/test-properties.h"
     32 
     33 namespace fst {
     34 
     35 struct FactorWeightOptions : CacheOptions {
     36   float delta;
     37   bool final_only;  // only factor final weights when true
     38 
     39   FactorWeightOptions(const CacheOptions &opts, float d, bool of)
     40       : CacheOptions(opts), delta(d), final_only(of) {}
     41 
     42   explicit FactorWeightOptions(float d, bool of = false)
     43       : delta(d), final_only(of) {}
     44 
     45   FactorWeightOptions(bool of = false)
     46       : delta(kDelta), final_only(of) {}
     47 };
     48 
     49 
// A factor iterator takes as argument a weight w and returns a
// sequence of pairs of weights (xi, yi) such that the sum of the
// products xi times yi is equal to w. If w is already fully factored,
// the iterator must return an empty sequence, i.e. Done() is true
// immediately after construction.
//
// template <class W>
// class FactorIterator {
//  public:
//   FactorIterator(W w);
//   bool Done() const;
//   void Next();
//   pair<W, W> Value() const;
//   void Reset();
// };
     64 
     65 
     66 // Factor trivially.
     67 template <class W>
     68 class IdentityFactor {
     69  public:
     70   IdentityFactor(const W &w) {}
     71   bool Done() const { return true; }
     72   void Next() {}
     73   pair<W, W> Value() const { return make_pair(W::One(), W::One()); } // unused
     74   void Reset() {}
     75 };
     76 
     77 
     78 // Factor a StringWeight w as 'ab' where 'a' is a label.
     79 template <typename L, StringType S = STRING_LEFT>
     80 class StringFactor {
     81  public:
     82   StringFactor(const StringWeight<L, S> &w)
     83       : weight_(w), done_(w.Size() <= 1) {}
     84 
     85   bool Done() const { return done_; }
     86 
     87   void Next() { done_ = true; }
     88 
     89   pair< StringWeight<L, S>, StringWeight<L, S> > Value() const {
     90     StringWeightIterator<L, S> iter(weight_);
     91     StringWeight<L, S> w1(iter.Value());
     92     StringWeight<L, S> w2;
     93     for (iter.Next(); !iter.Done(); iter.Next())
     94       w2.PushBack(iter.Value());
     95     return make_pair(w1, w2);
     96   }
     97 
     98   void Reset() { done_ = weight_.Size() <= 1; }
     99 
    100  private:
    101   StringWeight<L, S> weight_;
    102   bool done_;
    103 };
    104 
    105 
    106 // Factor a GallicWeight using StringFactor.
    107 template <class L, class W, StringType S = STRING_LEFT>
    108 class GallicFactor {
    109  public:
    110   GallicFactor(const GallicWeight<L, W, S> &w)
    111       : weight_(w), done_(w.Value1().Size() <= 1) {}
    112 
    113   bool Done() const { return done_; }
    114 
    115   void Next() { done_ = true; }
    116 
    117   pair< GallicWeight<L, W, S>, GallicWeight<L, W, S> > Value() const {
    118     StringFactor<L, S> iter(weight_.Value1());
    119     GallicWeight<L, W, S> w1(iter.Value().first, weight_.Value2());
    120     GallicWeight<L, W, S> w2(iter.Value().second, W::One());
    121     return make_pair(w1, w2);
    122   }
    123 
    124   void Reset() { done_ = weight_.Value1().Size() <= 1; }
    125 
    126  private:
    127   GallicWeight<L, W, S> weight_;
    128   bool done_;
    129 };
    130 
    131 
    132 // Implementation class for FactorWeight
    133 template <class A, class F>
    134 class FactorWeightFstImpl
    135     : public CacheImpl<A> {
    136  public:
    137   using FstImpl<A>::SetType;
    138   using FstImpl<A>::SetProperties;
    139   using FstImpl<A>::Properties;
    140   using FstImpl<A>::SetInputSymbols;
    141   using FstImpl<A>::SetOutputSymbols;
    142 
    143   using CacheBaseImpl< CacheState<A> >::HasStart;
    144   using CacheBaseImpl< CacheState<A> >::HasFinal;
    145   using CacheBaseImpl< CacheState<A> >::HasArcs;
    146 
    147   typedef A Arc;
    148   typedef typename A::Label Label;
    149   typedef typename A::Weight Weight;
    150   typedef typename A::StateId StateId;
    151   typedef F FactorIterator;
    152 
    153   struct Element {
    154     Element() {}
    155 
    156     Element(StateId s, Weight w) : state(s), weight(w) {}
    157 
    158     StateId state;     // Input state Id
    159     Weight weight;     // Residual weight
    160   };
    161 
    162   FactorWeightFstImpl(const Fst<A> &fst, const FactorWeightOptions &opts)
    163       : CacheImpl<A>(opts), fst_(fst.Copy()), delta_(opts.delta),
    164         final_only_(opts.final_only) {
    165     SetType("factor-weight");
    166     uint64 props = fst.Properties(kFstProperties, false);
    167     SetProperties(FactorWeightProperties(props), kCopyProperties);
    168 
    169     SetInputSymbols(fst.InputSymbols());
    170     SetOutputSymbols(fst.OutputSymbols());
    171   }
    172 
    173   ~FactorWeightFstImpl() {
    174     delete fst_;
    175   }
    176 
    177   StateId Start() {
    178     if (!HasStart()) {
    179       StateId s = fst_->Start();
    180       if (s == kNoStateId)
    181         return kNoStateId;
    182       StateId start = FindState(Element(fst_->Start(), Weight::One()));
    183       this->SetStart(start);
    184     }
    185     return CacheImpl<A>::Start();
    186   }
    187 
    188   Weight Final(StateId s) {
    189     if (!HasFinal(s)) {
    190       const Element &e = elements_[s];
    191       // TODO: fix so cast is unnecessary
    192       Weight w = e.state == kNoStateId
    193                  ? e.weight
    194                  : (Weight) Times(e.weight, fst_->Final(e.state));
    195       FactorIterator f(w);
    196       if (w != Weight::Zero() && f.Done())
    197         this->SetFinal(s, w);
    198       else
    199         this->SetFinal(s, Weight::Zero());
    200     }
    201     return CacheImpl<A>::Final(s);
    202   }
    203 
    204   size_t NumArcs(StateId s) {
    205     if (!HasArcs(s))
    206       Expand(s);
    207     return CacheImpl<A>::NumArcs(s);
    208   }
    209 
    210   size_t NumInputEpsilons(StateId s) {
    211     if (!HasArcs(s))
    212       Expand(s);
    213     return CacheImpl<A>::NumInputEpsilons(s);
    214   }
    215 
    216   size_t NumOutputEpsilons(StateId s) {
    217     if (!HasArcs(s))
    218       Expand(s);
    219     return CacheImpl<A>::NumOutputEpsilons(s);
    220   }
    221 
    222   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
    223     if (!HasArcs(s))
    224       Expand(s);
    225     CacheImpl<A>::InitArcIterator(s, data);
    226   }
    227 
    228 
    229   // Find state corresponding to an element. Create new state
    230   // if element not found.
    231   StateId FindState(const Element &e) {
    232     if (final_only_ && e.weight == Weight::One()) {
    233       while (unfactored_.size() <= (unsigned int)e.state)
    234         unfactored_.push_back(kNoStateId);
    235       if (unfactored_[e.state] == kNoStateId) {
    236         unfactored_[e.state] = elements_.size();
    237         elements_.push_back(e);
    238       }
    239       return unfactored_[e.state];
    240     } else {
    241       typename ElementMap::iterator eit = element_map_.find(e);
    242       if (eit != element_map_.end()) {
    243         return (*eit).second;
    244       } else {
    245         StateId s = elements_.size();
    246         elements_.push_back(e);
    247         element_map_.insert(pair<const Element, StateId>(e, s));
    248         return s;
    249       }
    250     }
    251   }
    252 
    253   // Computes the outgoing transitions from a state, creating new destination
    254   // states as needed.
    255   void Expand(StateId s) {
    256     Element e = elements_[s];
    257     if (e.state != kNoStateId) {
    258       for (ArcIterator< Fst<A> > ait(*fst_, e.state);
    259            !ait.Done();
    260            ait.Next()) {
    261         const A &arc = ait.Value();
    262         Weight w = Times(e.weight, arc.weight);
    263         FactorIterator fit(w);
    264         if (final_only_ || fit.Done()) {
    265           StateId d = FindState(Element(arc.nextstate, Weight::One()));
    266           this->AddArc(s, Arc(arc.ilabel, arc.olabel, w, d));
    267         } else {
    268           for (; !fit.Done(); fit.Next()) {
    269             const pair<Weight, Weight> &p = fit.Value();
    270             StateId d = FindState(Element(arc.nextstate,
    271                                           p.second.Quantize(delta_)));
    272             this->AddArc(s, Arc(arc.ilabel, arc.olabel, p.first, d));
    273           }
    274         }
    275       }
    276     }
    277     if ((e.state == kNoStateId) ||
    278         (fst_->Final(e.state) != Weight::Zero())) {
    279       Weight w = e.state == kNoStateId
    280                  ? e.weight
    281                  : Times(e.weight, fst_->Final(e.state));
    282       for (FactorIterator fit(w);
    283            !fit.Done();
    284            fit.Next()) {
    285         const pair<Weight, Weight> &p = fit.Value();
    286         StateId d = FindState(Element(kNoStateId,
    287                                       p.second.Quantize(delta_)));
    288         this->AddArc(s, Arc(0, 0, p.first, d));
    289       }
    290     }
    291     this->SetArcs(s);
    292   }
    293 
    294  private:
    295   // Equality function for Elements, assume weights have been quantized.
    296   class ElementEqual {
    297    public:
    298     bool operator()(const Element &x, const Element &y) const {
    299       return x.state == y.state && x.weight == y.weight;
    300     }
    301   };
    302 
    303   // Hash function for Elements to Fst states.
    304   class ElementKey {
    305    public:
    306     size_t operator()(const Element &x) const {
    307       return static_cast<size_t>(x.state * kPrime + x.weight.Hash());
    308     }
    309    private:
    310     static const int kPrime = 7853;
    311   };
    312 
    313   typedef hash_map<Element, StateId, ElementKey, ElementEqual> ElementMap;
    314 
    315   const Fst<A> *fst_;
    316   float delta_;
    317   bool final_only_;
    318   vector<Element> elements_;  // mapping Fst state to Elements
    319   ElementMap element_map_;    // mapping Elements to Fst state
    320   // mapping between old/new 'StateId' for states that do not need to
    321   // be factored when 'final_only_' is true
    322   vector<StateId> unfactored_;
    323 
    324   DISALLOW_EVIL_CONSTRUCTORS(FactorWeightFstImpl);
    325 };
    326 
    327 
    328 // FactorWeightFst takes as template parameter a FactorIterator as
    329 // defined above. The result of weight factoring is a transducer
    330 // equivalent to the input whose path weights have been factored
    331 // according to the FactorIterator. States and transitions will be
    332 // added as necessary. The algorithm is a generalization to arbitrary
    333 // weights of the second step of the input epsilon-normalization
    334 // algorithm due to Mohri, "Generic epsilon-removal and input
    335 // epsilon-normalization algorithms for weighted transducers",
    336 // International Journal of Computer Science 13(1): 129-143 (2002).
    337 template <class A, class F>
    338 class FactorWeightFst : public Fst<A> {
    339  public:
    340   friend class ArcIterator< FactorWeightFst<A, F> >;
    341   friend class CacheStateIterator< FactorWeightFst<A, F> >;
    342   friend class CacheArcIterator< FactorWeightFst<A, F> >;
    343 
    344   typedef A Arc;
    345   typedef typename A::Weight Weight;
    346   typedef typename A::StateId StateId;
    347   typedef CacheState<A> State;
    348 
    349   FactorWeightFst(const Fst<A> &fst)
    350       : impl_(new FactorWeightFstImpl<A, F>(fst, FactorWeightOptions())) {}
    351 
    352   FactorWeightFst(const Fst<A> &fst,  const FactorWeightOptions &opts)
    353       : impl_(new FactorWeightFstImpl<A, F>(fst, opts)) {}
    354   FactorWeightFst(const FactorWeightFst<A, F> &fst) : Fst<A>(fst), impl_(fst.impl_) {
    355     impl_->IncrRefCount();
    356   }
    357 
    358   virtual ~FactorWeightFst() { if (!impl_->DecrRefCount()) delete impl_;  }
    359 
    360   virtual StateId Start() const { return impl_->Start(); }
    361 
    362   virtual Weight Final(StateId s) const { return impl_->Final(s); }
    363 
    364   virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
    365 
    366   virtual size_t NumInputEpsilons(StateId s) const {
    367     return impl_->NumInputEpsilons(s);
    368   }
    369 
    370   virtual size_t NumOutputEpsilons(StateId s) const {
    371     return impl_->NumOutputEpsilons(s);
    372   }
    373 
    374   virtual uint64 Properties(uint64 mask, bool test) const {
    375     if (test) {
    376       uint64 known, test = TestProperties(*this, mask, &known);
    377       impl_->SetProperties(test, known);
    378       return test & mask;
    379     } else {
    380       return impl_->Properties(mask);
    381     }
    382   }
    383 
    384   virtual const string& Type() const { return impl_->Type(); }
    385 
    386   virtual FactorWeightFst<A, F> *Copy() const {
    387     return new FactorWeightFst<A, F>(*this);
    388   }
    389 
    390   virtual const SymbolTable* InputSymbols() const {
    391     return impl_->InputSymbols();
    392   }
    393 
    394   virtual const SymbolTable* OutputSymbols() const {
    395     return impl_->OutputSymbols();
    396   }
    397 
    398   virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
    399 
    400   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    401     impl_->InitArcIterator(s, data);
    402   }
    403 
    404  private:
    405   FactorWeightFstImpl<A, F> *Impl() { return impl_; }
    406 
    407   FactorWeightFstImpl<A, F> *impl_;
    408 
    409   void operator=(const FactorWeightFst<A, F> &fst);  // Disallow
    410 };
    411 
    412 
    413 // Specialization for FactorWeightFst.
    414 template<class A, class F>
    415 class StateIterator< FactorWeightFst<A, F> >
    416     : public CacheStateIterator< FactorWeightFst<A, F> > {
    417  public:
    418   explicit StateIterator(const FactorWeightFst<A, F> &fst)
    419       : CacheStateIterator< FactorWeightFst<A, F> >(fst) {}
    420 };
    421 
    422 
    423 // Specialization for FactorWeightFst.
    424 template <class A, class F>
    425 class ArcIterator< FactorWeightFst<A, F> >
    426     : public CacheArcIterator< FactorWeightFst<A, F> > {
    427  public:
    428   typedef typename A::StateId StateId;
    429 
    430   ArcIterator(const FactorWeightFst<A, F> &fst, StateId s)
    431       : CacheArcIterator< FactorWeightFst<A, F> >(fst, s) {
    432     if (!fst.impl_->HasArcs(s))
    433       fst.impl_->Expand(s);
    434   }
    435 
    436  private:
    437   DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
    438 };
    439 
    440 template <class A, class F> inline
    441 void FactorWeightFst<A, F>::InitStateIterator(StateIteratorData<A> *data) const
    442 {
    443   data->base = new StateIterator< FactorWeightFst<A, F> >(*this);
    444 }
    445 
    446 
    447 }  // namespace fst
    448 
    449 #endif // FST_LIB_FACTOR_WEIGHT_H__
    450