Home | History | Annotate | Download | only in lib
      1 // factor-weight.h
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Author: allauzen (at) cs.nyu.edu (Cyril Allauzen)
     16 //
     17 // \file
     18 // Classes to factor weights in an FST.
     19 
     20 #ifndef FST_LIB_FACTOR_WEIGHT_H__
     21 #define FST_LIB_FACTOR_WEIGHT_H__
     22 
     23 #include <algorithm>
     24 
     25 #include <unordered_map>
     26 #include <forward_list>
     27 
     28 #include "fst/lib/cache.h"
     29 #include "fst/lib/test-properties.h"
     30 
     31 namespace fst {
     32 
     33 struct FactorWeightOptions : CacheOptions {
     34   float delta;
     35   bool final_only;  // only factor final weights when true
     36 
     37   FactorWeightOptions(const CacheOptions &opts, float d, bool of)
     38       : CacheOptions(opts), delta(d), final_only(of) {}
     39 
     40   explicit FactorWeightOptions(float d, bool of = false)
     41       : delta(d), final_only(of) {}
     42 
     43   FactorWeightOptions(bool of = false)
     44       : delta(kDelta), final_only(of) {}
     45 };
     46 
     47 
     48 // A factor iterator takes as argument a weight w and returns a
     49 // sequence of pairs of weights (xi,yi) such that the sum of the
     50 // products xi times yi is equal to w. If w is fully factored,
     51 // the iterator should return nothing.
     52 //
     53 // template <class W>
     54 // class FactorIterator {
     55 //  public:
     56 //   FactorIterator(W w);
     57 //   bool Done() const;
     58 //   void Next();
     59 //   pair<W, W> Value() const;
     60 //   void Reset();
     61 // }
     62 
     63 
     64 // Factor trivially.
     65 template <class W>
     66 class IdentityFactor {
     67  public:
     68   IdentityFactor(const W &w) {}
     69   bool Done() const { return true; }
     70   void Next() {}
     71   pair<W, W> Value() const { return make_pair(W::One(), W::One()); } // unused
     72   void Reset() {}
     73 };
     74 
     75 
     76 // Factor a StringWeight w as 'ab' where 'a' is a label.
     77 template <typename L, StringType S = STRING_LEFT>
     78 class StringFactor {
     79  public:
     80   StringFactor(const StringWeight<L, S> &w)
     81       : weight_(w), done_(w.Size() <= 1) {}
     82 
     83   bool Done() const { return done_; }
     84 
     85   void Next() { done_ = true; }
     86 
     87   pair< StringWeight<L, S>, StringWeight<L, S> > Value() const {
     88     StringWeightIterator<L, S> iter(weight_);
     89     StringWeight<L, S> w1(iter.Value());
     90     StringWeight<L, S> w2;
     91     for (iter.Next(); !iter.Done(); iter.Next())
     92       w2.PushBack(iter.Value());
     93     return make_pair(w1, w2);
     94   }
     95 
     96   void Reset() { done_ = weight_.Size() <= 1; }
     97 
     98  private:
     99   StringWeight<L, S> weight_;
    100   bool done_;
    101 };
    102 
    103 
    104 // Factor a GallicWeight using StringFactor.
    105 template <class L, class W, StringType S = STRING_LEFT>
    106 class GallicFactor {
    107  public:
    108   GallicFactor(const GallicWeight<L, W, S> &w)
    109       : weight_(w), done_(w.Value1().Size() <= 1) {}
    110 
    111   bool Done() const { return done_; }
    112 
    113   void Next() { done_ = true; }
    114 
    115   pair< GallicWeight<L, W, S>, GallicWeight<L, W, S> > Value() const {
    116     StringFactor<L, S> iter(weight_.Value1());
    117     GallicWeight<L, W, S> w1(iter.Value().first, weight_.Value2());
    118     GallicWeight<L, W, S> w2(iter.Value().second, W::One());
    119     return make_pair(w1, w2);
    120   }
    121 
    122   void Reset() { done_ = weight_.Value1().Size() <= 1; }
    123 
    124  private:
    125   GallicWeight<L, W, S> weight_;
    126   bool done_;
    127 };
    128 
    129 
    130 // Implementation class for FactorWeight
    131 template <class A, class F>
    132 class FactorWeightFstImpl
    133     : public CacheImpl<A> {
    134  public:
    135   using FstImpl<A>::SetType;
    136   using FstImpl<A>::SetProperties;
    137   using FstImpl<A>::Properties;
    138   using FstImpl<A>::SetInputSymbols;
    139   using FstImpl<A>::SetOutputSymbols;
    140 
    141   using CacheBaseImpl< CacheState<A> >::HasStart;
    142   using CacheBaseImpl< CacheState<A> >::HasFinal;
    143   using CacheBaseImpl< CacheState<A> >::HasArcs;
    144 
    145   typedef A Arc;
    146   typedef typename A::Label Label;
    147   typedef typename A::Weight Weight;
    148   typedef typename A::StateId StateId;
    149   typedef F FactorIterator;
    150 
    151   struct Element {
    152     Element() {}
    153 
    154     Element(StateId s, Weight w) : state(s), weight(w) {}
    155 
    156     StateId state;     // Input state Id
    157     Weight weight;     // Residual weight
    158   };
    159 
    160   FactorWeightFstImpl(const Fst<A> &fst, const FactorWeightOptions &opts)
    161       : CacheImpl<A>(opts), fst_(fst.Copy()), delta_(opts.delta),
    162         final_only_(opts.final_only) {
    163     SetType("factor-weight");
    164     uint64 props = fst.Properties(kFstProperties, false);
    165     SetProperties(FactorWeightProperties(props), kCopyProperties);
    166 
    167     SetInputSymbols(fst.InputSymbols());
    168     SetOutputSymbols(fst.OutputSymbols());
    169   }
    170 
    171   ~FactorWeightFstImpl() {
    172     delete fst_;
    173   }
    174 
    175   StateId Start() {
    176     if (!HasStart()) {
    177       StateId s = fst_->Start();
    178       if (s == kNoStateId)
    179         return kNoStateId;
    180       StateId start = FindState(Element(fst_->Start(), Weight::One()));
    181       this->SetStart(start);
    182     }
    183     return CacheImpl<A>::Start();
    184   }
    185 
    186   Weight Final(StateId s) {
    187     if (!HasFinal(s)) {
    188       const Element &e = elements_[s];
    189       // TODO: fix so cast is unnecessary
    190       Weight w = e.state == kNoStateId
    191                  ? e.weight
    192                  : (Weight) Times(e.weight, fst_->Final(e.state));
    193       FactorIterator f(w);
    194       if (w != Weight::Zero() && f.Done())
    195         this->SetFinal(s, w);
    196       else
    197         this->SetFinal(s, Weight::Zero());
    198     }
    199     return CacheImpl<A>::Final(s);
    200   }
    201 
    202   size_t NumArcs(StateId s) {
    203     if (!HasArcs(s))
    204       Expand(s);
    205     return CacheImpl<A>::NumArcs(s);
    206   }
    207 
    208   size_t NumInputEpsilons(StateId s) {
    209     if (!HasArcs(s))
    210       Expand(s);
    211     return CacheImpl<A>::NumInputEpsilons(s);
    212   }
    213 
    214   size_t NumOutputEpsilons(StateId s) {
    215     if (!HasArcs(s))
    216       Expand(s);
    217     return CacheImpl<A>::NumOutputEpsilons(s);
    218   }
    219 
    220   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
    221     if (!HasArcs(s))
    222       Expand(s);
    223     CacheImpl<A>::InitArcIterator(s, data);
    224   }
    225 
    226 
    227   // Find state corresponding to an element. Create new state
    228   // if element not found.
    229   StateId FindState(const Element &e) {
    230     if (final_only_ && e.weight == Weight::One()) {
    231       while (unfactored_.size() <= (unsigned int)e.state)
    232         unfactored_.push_back(kNoStateId);
    233       if (unfactored_[e.state] == kNoStateId) {
    234         unfactored_[e.state] = elements_.size();
    235         elements_.push_back(e);
    236       }
    237       return unfactored_[e.state];
    238     } else {
    239       typename ElementMap::iterator eit = element_map_.find(e);
    240       if (eit != element_map_.end()) {
    241         return (*eit).second;
    242       } else {
    243         StateId s = elements_.size();
    244         elements_.push_back(e);
    245         element_map_.insert(pair<const Element, StateId>(e, s));
    246         return s;
    247       }
    248     }
    249   }
    250 
    251   // Computes the outgoing transitions from a state, creating new destination
    252   // states as needed.
    253   void Expand(StateId s) {
    254     Element e = elements_[s];
    255     if (e.state != kNoStateId) {
    256       for (ArcIterator< Fst<A> > ait(*fst_, e.state);
    257            !ait.Done();
    258            ait.Next()) {
    259         const A &arc = ait.Value();
    260         Weight w = Times(e.weight, arc.weight);
    261         FactorIterator fit(w);
    262         if (final_only_ || fit.Done()) {
    263           StateId d = FindState(Element(arc.nextstate, Weight::One()));
    264           this->AddArc(s, Arc(arc.ilabel, arc.olabel, w, d));
    265         } else {
    266           for (; !fit.Done(); fit.Next()) {
    267             const pair<Weight, Weight> &p = fit.Value();
    268             StateId d = FindState(Element(arc.nextstate,
    269                                           p.second.Quantize(delta_)));
    270             this->AddArc(s, Arc(arc.ilabel, arc.olabel, p.first, d));
    271           }
    272         }
    273       }
    274     }
    275     if ((e.state == kNoStateId) ||
    276         (fst_->Final(e.state) != Weight::Zero())) {
    277       Weight w = e.state == kNoStateId
    278                  ? e.weight
    279                  : Times(e.weight, fst_->Final(e.state));
    280       for (FactorIterator fit(w);
    281            !fit.Done();
    282            fit.Next()) {
    283         const pair<Weight, Weight> &p = fit.Value();
    284         StateId d = FindState(Element(kNoStateId,
    285                                       p.second.Quantize(delta_)));
    286         this->AddArc(s, Arc(0, 0, p.first, d));
    287       }
    288     }
    289     this->SetArcs(s);
    290   }
    291 
    292  private:
    293   // Equality function for Elements, assume weights have been quantized.
    294   class ElementEqual {
    295    public:
    296     bool operator()(const Element &x, const Element &y) const {
    297       return x.state == y.state && x.weight == y.weight;
    298     }
    299   };
    300 
    301   // Hash function for Elements to Fst states.
    302   class ElementKey {
    303    public:
    304     size_t operator()(const Element &x) const {
    305       return static_cast<size_t>(x.state * kPrime + x.weight.Hash());
    306     }
    307    private:
    308     static const int kPrime = 7853;
    309   };
    310 
    311   typedef std::unordered_map<Element, StateId, ElementKey, ElementEqual> ElementMap;
    312 
    313   const Fst<A> *fst_;
    314   float delta_;
    315   bool final_only_;
    316   vector<Element> elements_;  // mapping Fst state to Elements
    317   ElementMap element_map_;    // mapping Elements to Fst state
    318   // mapping between old/new 'StateId' for states that do not need to
    319   // be factored when 'final_only_' is true
    320   vector<StateId> unfactored_;
    321 
    322   DISALLOW_EVIL_CONSTRUCTORS(FactorWeightFstImpl);
    323 };
    324 
    325 
    326 // FactorWeightFst takes as template parameter a FactorIterator as
    327 // defined above. The result of weight factoring is a transducer
    328 // equivalent to the input whose path weights have been factored
    329 // according to the FactorIterator. States and transitions will be
    330 // added as necessary. The algorithm is a generalization to arbitrary
    331 // weights of the second step of the input epsilon-normalization
    332 // algorithm due to Mohri, "Generic epsilon-removal and input
    333 // epsilon-normalization algorithms for weighted transducers",
    334 // International Journal of Computer Science 13(1): 129-143 (2002).
    335 template <class A, class F>
    336 class FactorWeightFst : public Fst<A> {
    337  public:
    338   friend class ArcIterator< FactorWeightFst<A, F> >;
    339   friend class CacheStateIterator< FactorWeightFst<A, F> >;
    340   friend class CacheArcIterator< FactorWeightFst<A, F> >;
    341 
    342   typedef A Arc;
    343   typedef typename A::Weight Weight;
    344   typedef typename A::StateId StateId;
    345   typedef CacheState<A> State;
    346 
    347   FactorWeightFst(const Fst<A> &fst)
    348       : impl_(new FactorWeightFstImpl<A, F>(fst, FactorWeightOptions())) {}
    349 
    350   FactorWeightFst(const Fst<A> &fst,  const FactorWeightOptions &opts)
    351       : impl_(new FactorWeightFstImpl<A, F>(fst, opts)) {}
    352   FactorWeightFst(const FactorWeightFst<A, F> &fst) : Fst<A>(fst), impl_(fst.impl_) {
    353     impl_->IncrRefCount();
    354   }
    355 
    356   virtual ~FactorWeightFst() { if (!impl_->DecrRefCount()) delete impl_;  }
    357 
    358   virtual StateId Start() const { return impl_->Start(); }
    359 
    360   virtual Weight Final(StateId s) const { return impl_->Final(s); }
    361 
    362   virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
    363 
    364   virtual size_t NumInputEpsilons(StateId s) const {
    365     return impl_->NumInputEpsilons(s);
    366   }
    367 
    368   virtual size_t NumOutputEpsilons(StateId s) const {
    369     return impl_->NumOutputEpsilons(s);
    370   }
    371 
    372   virtual uint64 Properties(uint64 mask, bool test) const {
    373     if (test) {
    374       uint64 known, test = TestProperties(*this, mask, &known);
    375       impl_->SetProperties(test, known);
    376       return test & mask;
    377     } else {
    378       return impl_->Properties(mask);
    379     }
    380   }
    381 
    382   virtual const string& Type() const { return impl_->Type(); }
    383 
    384   virtual FactorWeightFst<A, F> *Copy() const {
    385     return new FactorWeightFst<A, F>(*this);
    386   }
    387 
    388   virtual const SymbolTable* InputSymbols() const {
    389     return impl_->InputSymbols();
    390   }
    391 
    392   virtual const SymbolTable* OutputSymbols() const {
    393     return impl_->OutputSymbols();
    394   }
    395 
    396   virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
    397 
    398   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    399     impl_->InitArcIterator(s, data);
    400   }
    401 
    402  private:
    403   FactorWeightFstImpl<A, F> *Impl() { return impl_; }
    404 
    405   FactorWeightFstImpl<A, F> *impl_;
    406 
    407   void operator=(const FactorWeightFst<A, F> &fst);  // Disallow
    408 };
    409 
    410 
    411 // Specialization for FactorWeightFst.
    412 template<class A, class F>
    413 class StateIterator< FactorWeightFst<A, F> >
    414     : public CacheStateIterator< FactorWeightFst<A, F> > {
    415  public:
    416   explicit StateIterator(const FactorWeightFst<A, F> &fst)
    417       : CacheStateIterator< FactorWeightFst<A, F> >(fst) {}
    418 };
    419 
    420 
    421 // Specialization for FactorWeightFst.
    422 template <class A, class F>
    423 class ArcIterator< FactorWeightFst<A, F> >
    424     : public CacheArcIterator< FactorWeightFst<A, F> > {
    425  public:
    426   typedef typename A::StateId StateId;
    427 
    428   ArcIterator(const FactorWeightFst<A, F> &fst, StateId s)
    429       : CacheArcIterator< FactorWeightFst<A, F> >(fst, s) {
    430     if (!fst.impl_->HasArcs(s))
    431       fst.impl_->Expand(s);
    432   }
    433 
    434  private:
    435   DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
    436 };
    437 
    438 template <class A, class F> inline
    439 void FactorWeightFst<A, F>::InitStateIterator(StateIteratorData<A> *data) const
    440 {
    441   data->base = new StateIterator< FactorWeightFst<A, F> >(*this);
    442 }
    443 
    444 
    445 }  // namespace fst
    446 
    447 #endif // FST_LIB_FACTOR_WEIGHT_H__
    448