Home | History | Annotate | Download | only in fst
      1 // factor-weight.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: allauzen (at) google.com (Cyril Allauzen)
     17 //
     18 // \file
     19 // Classes to factor weights in an FST.
     20 
     21 #ifndef FST_LIB_FACTOR_WEIGHT_H__
     22 #define FST_LIB_FACTOR_WEIGHT_H__
     23 
     24 #include <algorithm>
     25 #include <tr1/unordered_map>
     26 using std::tr1::unordered_map;
     27 using std::tr1::unordered_multimap;
     28 #include <string>
     29 #include <utility>
     30 using std::pair; using std::make_pair;
     31 #include <vector>
     32 using std::vector;
     33 
     34 #include <fst/cache.h>
     35 #include <fst/test-properties.h>
     36 
     37 
     38 namespace fst {
     39 
     40 const uint32 kFactorFinalWeights = 0x00000001;
     41 const uint32 kFactorArcWeights   = 0x00000002;
     42 
     43 template <class Arc>
     44 struct FactorWeightOptions : CacheOptions {
     45   typedef typename Arc::Label Label;
     46   float delta;
     47   uint32 mode;         // factor arc weights and/or final weights
     48   Label final_ilabel;  // input label of arc created when factoring final w's
     49   Label final_olabel;  // output label of arc created when factoring final w's
     50 
     51   FactorWeightOptions(const CacheOptions &opts, float d,
     52                       uint32 m = kFactorArcWeights | kFactorFinalWeights,
     53                       Label il = 0, Label ol = 0)
     54       : CacheOptions(opts), delta(d), mode(m), final_ilabel(il),
     55         final_olabel(ol) {}
     56 
     57   explicit FactorWeightOptions(
     58       float d, uint32 m = kFactorArcWeights | kFactorFinalWeights,
     59       Label il = 0, Label ol = 0)
     60       : delta(d), mode(m), final_ilabel(il), final_olabel(ol) {}
     61 
     62   FactorWeightOptions(uint32 m = kFactorArcWeights | kFactorFinalWeights,
     63                       Label il = 0, Label ol = 0)
     64       : delta(kDelta), mode(m), final_ilabel(il), final_olabel(ol) {}
     65 };
     66 
     67 
     68 // A factor iterator takes as argument a weight w and returns a
     69 // sequence of pairs of weights (xi,yi) such that the sum of the
     70 // products xi times yi is equal to w. If w is fully factored,
     71 // the iterator should return nothing.
     72 //
     73 // template <class W>
     74 // class FactorIterator {
     75 //  public:
     76 //   FactorIterator(W w);
     77 //   bool Done() const;
     78 //   void Next();
     79 //   pair<W, W> Value() const;
     80 //   void Reset();
     81 // }
     82 
     83 
     84 // Factor trivially.
     85 template <class W>
     86 class IdentityFactor {
     87  public:
     88   IdentityFactor(const W &w) {}
     89   bool Done() const { return true; }
     90   void Next() {}
     91   pair<W, W> Value() const { return make_pair(W::One(), W::One()); } // unused
     92   void Reset() {}
     93 };
     94 
     95 
     96 // Factor a StringWeight w as 'ab' where 'a' is a label.
     97 template <typename L, StringType S = STRING_LEFT>
     98 class StringFactor {
     99  public:
    100   StringFactor(const StringWeight<L, S> &w)
    101       : weight_(w), done_(w.Size() <= 1) {}
    102 
    103   bool Done() const { return done_; }
    104 
    105   void Next() { done_ = true; }
    106 
    107   pair< StringWeight<L, S>, StringWeight<L, S> > Value() const {
    108     StringWeightIterator<L, S> iter(weight_);
    109     StringWeight<L, S> w1(iter.Value());
    110     StringWeight<L, S> w2;
    111     for (iter.Next(); !iter.Done(); iter.Next())
    112       w2.PushBack(iter.Value());
    113     return make_pair(w1, w2);
    114   }
    115 
    116   void Reset() { done_ = weight_.Size() <= 1; }
    117 
    118  private:
    119   StringWeight<L, S> weight_;
    120   bool done_;
    121 };
    122 
    123 
    124 // Factor a GallicWeight using StringFactor.
    125 template <class L, class W, StringType S = STRING_LEFT>
    126 class GallicFactor {
    127  public:
    128   GallicFactor(const GallicWeight<L, W, S> &w)
    129       : weight_(w), done_(w.Value1().Size() <= 1) {}
    130 
    131   bool Done() const { return done_; }
    132 
    133   void Next() { done_ = true; }
    134 
    135   pair< GallicWeight<L, W, S>, GallicWeight<L, W, S> > Value() const {
    136     StringFactor<L, S> iter(weight_.Value1());
    137     GallicWeight<L, W, S> w1(iter.Value().first, weight_.Value2());
    138     GallicWeight<L, W, S> w2(iter.Value().second, W::One());
    139     return make_pair(w1, w2);
    140   }
    141 
    142   void Reset() { done_ = weight_.Value1().Size() <= 1; }
    143 
    144  private:
    145   GallicWeight<L, W, S> weight_;
    146   bool done_;
    147 };
    148 
    149 
    150 // Implementation class for FactorWeight
    151 template <class A, class F>
    152 class FactorWeightFstImpl
    153     : public CacheImpl<A> {
    154  public:
    155   using FstImpl<A>::SetType;
    156   using FstImpl<A>::SetProperties;
    157   using FstImpl<A>::SetInputSymbols;
    158   using FstImpl<A>::SetOutputSymbols;
    159 
    160   using CacheBaseImpl< CacheState<A> >::PushArc;
    161   using CacheBaseImpl< CacheState<A> >::HasStart;
    162   using CacheBaseImpl< CacheState<A> >::HasFinal;
    163   using CacheBaseImpl< CacheState<A> >::HasArcs;
    164   using CacheBaseImpl< CacheState<A> >::SetArcs;
    165   using CacheBaseImpl< CacheState<A> >::SetFinal;
    166   using CacheBaseImpl< CacheState<A> >::SetStart;
    167 
    168   typedef A Arc;
    169   typedef typename A::Label Label;
    170   typedef typename A::Weight Weight;
    171   typedef typename A::StateId StateId;
    172   typedef F FactorIterator;
    173 
    174   struct Element {
    175     Element() {}
    176 
    177     Element(StateId s, Weight w) : state(s), weight(w) {}
    178 
    179     StateId state;     // Input state Id
    180     Weight weight;     // Residual weight
    181   };
    182 
    183   FactorWeightFstImpl(const Fst<A> &fst, const FactorWeightOptions<A> &opts)
    184       : CacheImpl<A>(opts),
    185         fst_(fst.Copy()),
    186         delta_(opts.delta),
    187         mode_(opts.mode),
    188         final_ilabel_(opts.final_ilabel),
    189         final_olabel_(opts.final_olabel) {
    190     SetType("factor_weight");
    191     uint64 props = fst.Properties(kFstProperties, false);
    192     SetProperties(FactorWeightProperties(props), kCopyProperties);
    193 
    194     SetInputSymbols(fst.InputSymbols());
    195     SetOutputSymbols(fst.OutputSymbols());
    196 
    197     if (mode_ == 0)
    198       LOG(WARNING) << "FactorWeightFst: factor mode is set to 0: "
    199                    << "factoring neither arc weights nor final weights.";
    200   }
    201 
    202   FactorWeightFstImpl(const FactorWeightFstImpl<A, F> &impl)
    203       : CacheImpl<A>(impl),
    204         fst_(impl.fst_->Copy(true)),
    205         delta_(impl.delta_),
    206         mode_(impl.mode_),
    207         final_ilabel_(impl.final_ilabel_),
    208         final_olabel_(impl.final_olabel_) {
    209     SetType("factor_weight");
    210     SetProperties(impl.Properties(), kCopyProperties);
    211     SetInputSymbols(impl.InputSymbols());
    212     SetOutputSymbols(impl.OutputSymbols());
    213   }
    214 
    215   ~FactorWeightFstImpl() {
    216     delete fst_;
    217   }
    218 
    219   StateId Start() {
    220     if (!HasStart()) {
    221       StateId s = fst_->Start();
    222       if (s == kNoStateId)
    223         return kNoStateId;
    224       StateId start = FindState(Element(fst_->Start(), Weight::One()));
    225       SetStart(start);
    226     }
    227     return CacheImpl<A>::Start();
    228   }
    229 
    230   Weight Final(StateId s) {
    231     if (!HasFinal(s)) {
    232       const Element &e = elements_[s];
    233       // TODO: fix so cast is unnecessary
    234       Weight w = e.state == kNoStateId
    235                  ? e.weight
    236                  : (Weight) Times(e.weight, fst_->Final(e.state));
    237       FactorIterator f(w);
    238       if (!(mode_ & kFactorFinalWeights) || f.Done())
    239         SetFinal(s, w);
    240       else
    241         SetFinal(s, Weight::Zero());
    242     }
    243     return CacheImpl<A>::Final(s);
    244   }
    245 
    246   size_t NumArcs(StateId s) {
    247     if (!HasArcs(s))
    248       Expand(s);
    249     return CacheImpl<A>::NumArcs(s);
    250   }
    251 
    252   size_t NumInputEpsilons(StateId s) {
    253     if (!HasArcs(s))
    254       Expand(s);
    255     return CacheImpl<A>::NumInputEpsilons(s);
    256   }
    257 
    258   size_t NumOutputEpsilons(StateId s) {
    259     if (!HasArcs(s))
    260       Expand(s);
    261     return CacheImpl<A>::NumOutputEpsilons(s);
    262   }
    263 
    264   uint64 Properties() const { return Properties(kFstProperties); }
    265 
    266   // Set error if found; return FST impl properties.
    267   uint64 Properties(uint64 mask) const {
    268     if ((mask & kError) && fst_->Properties(kError, false))
    269       SetProperties(kError, kError);
    270     return FstImpl<Arc>::Properties(mask);
    271   }
    272 
    273   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
    274     if (!HasArcs(s))
    275       Expand(s);
    276     CacheImpl<A>::InitArcIterator(s, data);
    277   }
    278 
    279 
    280   // Find state corresponding to an element. Create new state
    281   // if element not found.
    282   StateId FindState(const Element &e) {
    283     if (!(mode_ & kFactorArcWeights) && e.weight == Weight::One()) {
    284       while (unfactored_.size() <= e.state)
    285         unfactored_.push_back(kNoStateId);
    286       if (unfactored_[e.state] == kNoStateId) {
    287         unfactored_[e.state] = elements_.size();
    288         elements_.push_back(e);
    289       }
    290       return unfactored_[e.state];
    291     } else {
    292       typename ElementMap::iterator eit = element_map_.find(e);
    293       if (eit != element_map_.end()) {
    294         return (*eit).second;
    295       } else {
    296         StateId s = elements_.size();
    297         elements_.push_back(e);
    298         element_map_.insert(pair<const Element, StateId>(e, s));
    299         return s;
    300       }
    301     }
    302   }
    303 
    304   // Computes the outgoing transitions from a state, creating new destination
    305   // states as needed.
    306   void Expand(StateId s) {
    307     Element e = elements_[s];
    308     if (e.state != kNoStateId) {
    309       for (ArcIterator< Fst<A> > ait(*fst_, e.state);
    310            !ait.Done();
    311            ait.Next()) {
    312         const A &arc = ait.Value();
    313         Weight w = Times(e.weight, arc.weight);
    314         FactorIterator fit(w);
    315         if (!(mode_ & kFactorArcWeights) || fit.Done()) {
    316           StateId d = FindState(Element(arc.nextstate, Weight::One()));
    317           PushArc(s, Arc(arc.ilabel, arc.olabel, w, d));
    318         } else {
    319           for (; !fit.Done(); fit.Next()) {
    320             const pair<Weight, Weight> &p = fit.Value();
    321             StateId d = FindState(Element(arc.nextstate,
    322                                           p.second.Quantize(delta_)));
    323             PushArc(s, Arc(arc.ilabel, arc.olabel, p.first, d));
    324           }
    325         }
    326       }
    327     }
    328 
    329     if ((mode_ & kFactorFinalWeights) &&
    330         ((e.state == kNoStateId) ||
    331          (fst_->Final(e.state) != Weight::Zero()))) {
    332       Weight w = e.state == kNoStateId
    333                  ? e.weight
    334                  : Times(e.weight, fst_->Final(e.state));
    335       for (FactorIterator fit(w);
    336            !fit.Done();
    337            fit.Next()) {
    338         const pair<Weight, Weight> &p = fit.Value();
    339         StateId d = FindState(Element(kNoStateId,
    340                                       p.second.Quantize(delta_)));
    341         PushArc(s, Arc(final_ilabel_, final_olabel_, p.first, d));
    342       }
    343     }
    344     SetArcs(s);
    345   }
    346 
    347  private:
    348   static const size_t kPrime = 7853;
    349 
    350   // Equality function for Elements, assume weights have been quantized.
    351   class ElementEqual {
    352    public:
    353     bool operator()(const Element &x, const Element &y) const {
    354       return x.state == y.state && x.weight == y.weight;
    355     }
    356   };
    357 
    358   // Hash function for Elements to Fst states.
    359   class ElementKey {
    360    public:
    361     size_t operator()(const Element &x) const {
    362       return static_cast<size_t>(x.state * kPrime + x.weight.Hash());
    363     }
    364    private:
    365   };
    366 
    367   typedef unordered_map<Element, StateId, ElementKey, ElementEqual> ElementMap;
    368 
    369   const Fst<A> *fst_;
    370   float delta_;
    371   uint32 mode_;               // factoring arc and/or final weights
    372   Label final_ilabel_;        // ilabel of arc created when factoring final w's
    373   Label final_olabel_;        // olabel of arc created when factoring final w's
    374   vector<Element> elements_;  // mapping Fst state to Elements
    375   ElementMap element_map_;    // mapping Elements to Fst state
    376   // mapping between old/new 'StateId' for states that do not need to
    377   // be factored when 'mode_' is '0' or 'kFactorFinalWeights'
    378   vector<StateId> unfactored_;
    379 
    380   void operator=(const FactorWeightFstImpl<A, F> &);  // disallow
    381 };
    382 
    383 template <class A, class F> const size_t FactorWeightFstImpl<A, F>::kPrime;
    384 
    385 
    386 // FactorWeightFst takes as template parameter a FactorIterator as
    387 // defined above. The result of weight factoring is a transducer
    388 // equivalent to the input whose path weights have been factored
    389 // according to the FactorIterator. States and transitions will be
    390 // added as necessary. The algorithm is a generalization to arbitrary
    391 // weights of the second step of the input epsilon-normalization
    392 // algorithm due to Mohri, "Generic epsilon-removal and input
    393 // epsilon-normalization algorithms for weighted transducers",
    394 // International Journal of Computer Science 13(1): 129-143 (2002).
    395 //
    396 // This class attaches interface to implementation and handles
    397 // reference counting, delegating most methods to ImplToFst.
    398 template <class A, class F>
    399 class FactorWeightFst : public ImplToFst< FactorWeightFstImpl<A, F> > {
    400  public:
    401   friend class ArcIterator< FactorWeightFst<A, F> >;
    402   friend class StateIterator< FactorWeightFst<A, F> >;
    403 
    404   typedef A Arc;
    405   typedef typename A::Weight Weight;
    406   typedef typename A::StateId StateId;
    407   typedef CacheState<A> State;
    408   typedef FactorWeightFstImpl<A, F> Impl;
    409 
    410   FactorWeightFst(const Fst<A> &fst)
    411       : ImplToFst<Impl>(new Impl(fst, FactorWeightOptions<A>())) {}
    412 
    413   FactorWeightFst(const Fst<A> &fst,  const FactorWeightOptions<A> &opts)
    414       : ImplToFst<Impl>(new Impl(fst, opts)) {}
    415 
    416   // See Fst<>::Copy() for doc.
    417   FactorWeightFst(const FactorWeightFst<A, F> &fst, bool copy)
    418       : ImplToFst<Impl>(fst, copy) {}
    419 
    420   // Get a copy of this FactorWeightFst. See Fst<>::Copy() for further doc.
    421   virtual FactorWeightFst<A, F> *Copy(bool copy = false) const {
    422     return new FactorWeightFst<A, F>(*this, copy);
    423   }
    424 
    425   virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
    426 
    427   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    428     GetImpl()->InitArcIterator(s, data);
    429   }
    430 
    431  private:
    432   // Makes visible to friends.
    433   Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
    434 
    435   void operator=(const FactorWeightFst<A, F> &fst);  // Disallow
    436 };
    437 
    438 
    439 // Specialization for FactorWeightFst.
    440 template<class A, class F>
    441 class StateIterator< FactorWeightFst<A, F> >
    442     : public CacheStateIterator< FactorWeightFst<A, F> > {
    443  public:
    444   explicit StateIterator(const FactorWeightFst<A, F> &fst)
    445       : CacheStateIterator< FactorWeightFst<A, F> >(fst, fst.GetImpl()) {}
    446 };
    447 
    448 
    449 // Specialization for FactorWeightFst.
    450 template <class A, class F>
    451 class ArcIterator< FactorWeightFst<A, F> >
    452     : public CacheArcIterator< FactorWeightFst<A, F> > {
    453  public:
    454   typedef typename A::StateId StateId;
    455 
    456   ArcIterator(const FactorWeightFst<A, F> &fst, StateId s)
    457       : CacheArcIterator< FactorWeightFst<A, F> >(fst.GetImpl(), s) {
    458     if (!fst.GetImpl()->HasArcs(s))
    459       fst.GetImpl()->Expand(s);
    460   }
    461 
    462  private:
    463   DISALLOW_COPY_AND_ASSIGN(ArcIterator);
    464 };
    465 
    466 template <class A, class F> inline
    467 void FactorWeightFst<A, F>::InitStateIterator(StateIteratorData<A> *data) const
    468 {
    469   data->base = new StateIterator< FactorWeightFst<A, F> >(*this);
    470 }
    471 
    472 
    473 }  // namespace fst
    474 
    475 #endif // FST_LIB_FACTOR_WEIGHT_H__
    476