Home | History | Annotate | Download | only in fst
      1 // factor-weight.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: allauzen (at) google.com (Cyril Allauzen)
     17 //
     18 // \file
     19 // Classes to factor weights in an FST.
     20 
     21 #ifndef FST_LIB_FACTOR_WEIGHT_H__
     22 #define FST_LIB_FACTOR_WEIGHT_H__
     23 
     24 #include <algorithm>
     25 #include <unordered_map>
     26 using std::tr1::unordered_map;
     27 using std::tr1::unordered_multimap;
     28 #include <fst/slist.h>
     29 #include <string>
     30 #include <utility>
     31 using std::pair; using std::make_pair;
     32 #include <vector>
     33 using std::vector;
     34 
     35 #include <fst/cache.h>
     36 #include <fst/test-properties.h>
     37 
     38 
     39 namespace fst {
     40 
     41 const uint32 kFactorFinalWeights = 0x00000001;
     42 const uint32 kFactorArcWeights   = 0x00000002;
     43 
     44 template <class Arc>
     45 struct FactorWeightOptions : CacheOptions {
     46   typedef typename Arc::Label Label;
     47   float delta;
     48   uint32 mode;         // factor arc weights and/or final weights
     49   Label final_ilabel;  // input label of arc created when factoring final w's
     50   Label final_olabel;  // output label of arc created when factoring final w's
     51 
     52   FactorWeightOptions(const CacheOptions &opts, float d,
     53                       uint32 m = kFactorArcWeights | kFactorFinalWeights,
     54                       Label il = 0, Label ol = 0)
     55       : CacheOptions(opts), delta(d), mode(m), final_ilabel(il),
     56         final_olabel(ol) {}
     57 
     58   explicit FactorWeightOptions(
     59       float d, uint32 m = kFactorArcWeights | kFactorFinalWeights,
     60       Label il = 0, Label ol = 0)
     61       : delta(d), mode(m), final_ilabel(il), final_olabel(ol) {}
     62 
     63   FactorWeightOptions(uint32 m = kFactorArcWeights | kFactorFinalWeights,
     64                       Label il = 0, Label ol = 0)
     65       : delta(kDelta), mode(m), final_ilabel(il), final_olabel(ol) {}
     66 };
     67 
     68 
// A factor iterator is constructed from a weight w and enumerates a
// sequence of weight pairs (x_i, y_i) such that the sum of the products
// x_i * y_i equals w. If w is already fully factored, the iterator yields
// no pairs (Done() is true immediately).
//
// template <class W>
// class FactorIterator {
//  public:
//   FactorIterator(W w);
//   bool Done() const;
//   void Next();
//   pair<W, W> Value() const;
//   void Reset();
// };
     83 
     84 
     85 // Factor trivially.
     86 template <class W>
     87 class IdentityFactor {
     88  public:
     89   IdentityFactor(const W &w) {}
     90   bool Done() const { return true; }
     91   void Next() {}
     92   pair<W, W> Value() const { return make_pair(W::One(), W::One()); } // unused
     93   void Reset() {}
     94 };
     95 
     96 
     97 // Factor a StringWeight w as 'ab' where 'a' is a label.
     98 template <typename L, StringType S = STRING_LEFT>
     99 class StringFactor {
    100  public:
    101   StringFactor(const StringWeight<L, S> &w)
    102       : weight_(w), done_(w.Size() <= 1) {}
    103 
    104   bool Done() const { return done_; }
    105 
    106   void Next() { done_ = true; }
    107 
    108   pair< StringWeight<L, S>, StringWeight<L, S> > Value() const {
    109     StringWeightIterator<L, S> iter(weight_);
    110     StringWeight<L, S> w1(iter.Value());
    111     StringWeight<L, S> w2;
    112     for (iter.Next(); !iter.Done(); iter.Next())
    113       w2.PushBack(iter.Value());
    114     return make_pair(w1, w2);
    115   }
    116 
    117   void Reset() { done_ = weight_.Size() <= 1; }
    118 
    119  private:
    120   StringWeight<L, S> weight_;
    121   bool done_;
    122 };
    123 
    124 
    125 // Factor a GallicWeight using StringFactor.
    126 template <class L, class W, StringType S = STRING_LEFT>
    127 class GallicFactor {
    128  public:
    129   GallicFactor(const GallicWeight<L, W, S> &w)
    130       : weight_(w), done_(w.Value1().Size() <= 1) {}
    131 
    132   bool Done() const { return done_; }
    133 
    134   void Next() { done_ = true; }
    135 
    136   pair< GallicWeight<L, W, S>, GallicWeight<L, W, S> > Value() const {
    137     StringFactor<L, S> iter(weight_.Value1());
    138     GallicWeight<L, W, S> w1(iter.Value().first, weight_.Value2());
    139     GallicWeight<L, W, S> w2(iter.Value().second, W::One());
    140     return make_pair(w1, w2);
    141   }
    142 
    143   void Reset() { done_ = weight_.Value1().Size() <= 1; }
    144 
    145  private:
    146   GallicWeight<L, W, S> weight_;
    147   bool done_;
    148 };
    149 
    150 
    151 // Implementation class for FactorWeight
    152 template <class A, class F>
    153 class FactorWeightFstImpl
    154     : public CacheImpl<A> {
    155  public:
    156   using FstImpl<A>::SetType;
    157   using FstImpl<A>::SetProperties;
    158   using FstImpl<A>::SetInputSymbols;
    159   using FstImpl<A>::SetOutputSymbols;
    160 
    161   using CacheBaseImpl< CacheState<A> >::PushArc;
    162   using CacheBaseImpl< CacheState<A> >::HasStart;
    163   using CacheBaseImpl< CacheState<A> >::HasFinal;
    164   using CacheBaseImpl< CacheState<A> >::HasArcs;
    165   using CacheBaseImpl< CacheState<A> >::SetArcs;
    166   using CacheBaseImpl< CacheState<A> >::SetFinal;
    167   using CacheBaseImpl< CacheState<A> >::SetStart;
    168 
    169   typedef A Arc;
    170   typedef typename A::Label Label;
    171   typedef typename A::Weight Weight;
    172   typedef typename A::StateId StateId;
    173   typedef F FactorIterator;
    174 
    175   struct Element {
    176     Element() {}
    177 
    178     Element(StateId s, Weight w) : state(s), weight(w) {}
    179 
    180     StateId state;     // Input state Id
    181     Weight weight;     // Residual weight
    182   };
    183 
    184   FactorWeightFstImpl(const Fst<A> &fst, const FactorWeightOptions<A> &opts)
    185       : CacheImpl<A>(opts),
    186         fst_(fst.Copy()),
    187         delta_(opts.delta),
    188         mode_(opts.mode),
    189         final_ilabel_(opts.final_ilabel),
    190         final_olabel_(opts.final_olabel) {
    191     SetType("factor_weight");
    192     uint64 props = fst.Properties(kFstProperties, false);
    193     SetProperties(FactorWeightProperties(props), kCopyProperties);
    194 
    195     SetInputSymbols(fst.InputSymbols());
    196     SetOutputSymbols(fst.OutputSymbols());
    197 
    198     if (mode_ == 0)
    199       LOG(WARNING) << "FactorWeightFst: factor mode is set to 0: "
    200                    << "factoring neither arc weights nor final weights.";
    201   }
    202 
    203   FactorWeightFstImpl(const FactorWeightFstImpl<A, F> &impl)
    204       : CacheImpl<A>(impl),
    205         fst_(impl.fst_->Copy(true)),
    206         delta_(impl.delta_),
    207         mode_(impl.mode_),
    208         final_ilabel_(impl.final_ilabel_),
    209         final_olabel_(impl.final_olabel_) {
    210     SetType("factor_weight");
    211     SetProperties(impl.Properties(), kCopyProperties);
    212     SetInputSymbols(impl.InputSymbols());
    213     SetOutputSymbols(impl.OutputSymbols());
    214   }
    215 
    216   ~FactorWeightFstImpl() {
    217     delete fst_;
    218   }
    219 
    220   StateId Start() {
    221     if (!HasStart()) {
    222       StateId s = fst_->Start();
    223       if (s == kNoStateId)
    224         return kNoStateId;
    225       StateId start = FindState(Element(fst_->Start(), Weight::One()));
    226       SetStart(start);
    227     }
    228     return CacheImpl<A>::Start();
    229   }
    230 
    231   Weight Final(StateId s) {
    232     if (!HasFinal(s)) {
    233       const Element &e = elements_[s];
    234       // TODO: fix so cast is unnecessary
    235       Weight w = e.state == kNoStateId
    236                  ? e.weight
    237                  : (Weight) Times(e.weight, fst_->Final(e.state));
    238       FactorIterator f(w);
    239       if (!(mode_ & kFactorFinalWeights) || f.Done())
    240         SetFinal(s, w);
    241       else
    242         SetFinal(s, Weight::Zero());
    243     }
    244     return CacheImpl<A>::Final(s);
    245   }
    246 
    247   size_t NumArcs(StateId s) {
    248     if (!HasArcs(s))
    249       Expand(s);
    250     return CacheImpl<A>::NumArcs(s);
    251   }
    252 
    253   size_t NumInputEpsilons(StateId s) {
    254     if (!HasArcs(s))
    255       Expand(s);
    256     return CacheImpl<A>::NumInputEpsilons(s);
    257   }
    258 
    259   size_t NumOutputEpsilons(StateId s) {
    260     if (!HasArcs(s))
    261       Expand(s);
    262     return CacheImpl<A>::NumOutputEpsilons(s);
    263   }
    264 
    265   uint64 Properties() const { return Properties(kFstProperties); }
    266 
    267   // Set error if found; return FST impl properties.
    268   uint64 Properties(uint64 mask) const {
    269     if ((mask & kError) && fst_->Properties(kError, false))
    270       SetProperties(kError, kError);
    271     return FstImpl<Arc>::Properties(mask);
    272   }
    273 
    274   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
    275     if (!HasArcs(s))
    276       Expand(s);
    277     CacheImpl<A>::InitArcIterator(s, data);
    278   }
    279 
    280 
    281   // Find state corresponding to an element. Create new state
    282   // if element not found.
    283   StateId FindState(const Element &e) {
    284     if (!(mode_ & kFactorArcWeights) && e.weight == Weight::One()) {
    285       while (unfactored_.size() <= e.state)
    286         unfactored_.push_back(kNoStateId);
    287       if (unfactored_[e.state] == kNoStateId) {
    288         unfactored_[e.state] = elements_.size();
    289         elements_.push_back(e);
    290       }
    291       return unfactored_[e.state];
    292     } else {
    293       typename ElementMap::iterator eit = element_map_.find(e);
    294       if (eit != element_map_.end()) {
    295         return (*eit).second;
    296       } else {
    297         StateId s = elements_.size();
    298         elements_.push_back(e);
    299         element_map_.insert(pair<const Element, StateId>(e, s));
    300         return s;
    301       }
    302     }
    303   }
    304 
    305   // Computes the outgoing transitions from a state, creating new destination
    306   // states as needed.
    307   void Expand(StateId s) {
    308     Element e = elements_[s];
    309     if (e.state != kNoStateId) {
    310       for (ArcIterator< Fst<A> > ait(*fst_, e.state);
    311            !ait.Done();
    312            ait.Next()) {
    313         const A &arc = ait.Value();
    314         Weight w = Times(e.weight, arc.weight);
    315         FactorIterator fit(w);
    316         if (!(mode_ & kFactorArcWeights) || fit.Done()) {
    317           StateId d = FindState(Element(arc.nextstate, Weight::One()));
    318           PushArc(s, Arc(arc.ilabel, arc.olabel, w, d));
    319         } else {
    320           for (; !fit.Done(); fit.Next()) {
    321             const pair<Weight, Weight> &p = fit.Value();
    322             StateId d = FindState(Element(arc.nextstate,
    323                                           p.second.Quantize(delta_)));
    324             PushArc(s, Arc(arc.ilabel, arc.olabel, p.first, d));
    325           }
    326         }
    327       }
    328     }
    329 
    330     if ((mode_ & kFactorFinalWeights) &&
    331         ((e.state == kNoStateId) ||
    332          (fst_->Final(e.state) != Weight::Zero()))) {
    333       Weight w = e.state == kNoStateId
    334                  ? e.weight
    335                  : Times(e.weight, fst_->Final(e.state));
    336       for (FactorIterator fit(w);
    337            !fit.Done();
    338            fit.Next()) {
    339         const pair<Weight, Weight> &p = fit.Value();
    340         StateId d = FindState(Element(kNoStateId,
    341                                       p.second.Quantize(delta_)));
    342         PushArc(s, Arc(final_ilabel_, final_olabel_, p.first, d));
    343       }
    344     }
    345     SetArcs(s);
    346   }
    347 
    348  private:
    349   static const size_t kPrime = 7853;
    350 
    351   // Equality function for Elements, assume weights have been quantized.
    352   class ElementEqual {
    353    public:
    354     bool operator()(const Element &x, const Element &y) const {
    355       return x.state == y.state && x.weight == y.weight;
    356     }
    357   };
    358 
    359   // Hash function for Elements to Fst states.
    360   class ElementKey {
    361    public:
    362     size_t operator()(const Element &x) const {
    363       return static_cast<size_t>(x.state * kPrime + x.weight.Hash());
    364     }
    365    private:
    366   };
    367 
    368   typedef unordered_map<Element, StateId, ElementKey, ElementEqual> ElementMap;
    369 
    370   const Fst<A> *fst_;
    371   float delta_;
    372   uint32 mode_;               // factoring arc and/or final weights
    373   Label final_ilabel_;        // ilabel of arc created when factoring final w's
    374   Label final_olabel_;        // olabel of arc created when factoring final w's
    375   vector<Element> elements_;  // mapping Fst state to Elements
    376   ElementMap element_map_;    // mapping Elements to Fst state
    377   // mapping between old/new 'StateId' for states that do not need to
    378   // be factored when 'mode_' is '0' or 'kFactorFinalWeights'
    379   vector<StateId> unfactored_;
    380 
    381   void operator=(const FactorWeightFstImpl<A, F> &);  // disallow
    382 };
    383 
    384 template <class A, class F> const size_t FactorWeightFstImpl<A, F>::kPrime;
    385 
    386 
    387 // FactorWeightFst takes as template parameter a FactorIterator as
    388 // defined above. The result of weight factoring is a transducer
    389 // equivalent to the input whose path weights have been factored
    390 // according to the FactorIterator. States and transitions will be
    391 // added as necessary. The algorithm is a generalization to arbitrary
    392 // weights of the second step of the input epsilon-normalization
    393 // algorithm due to Mohri, "Generic epsilon-removal and input
    394 // epsilon-normalization algorithms for weighted transducers",
    395 // International Journal of Computer Science 13(1): 129-143 (2002).
    396 //
    397 // This class attaches interface to implementation and handles
    398 // reference counting, delegating most methods to ImplToFst.
    399 template <class A, class F>
    400 class FactorWeightFst : public ImplToFst< FactorWeightFstImpl<A, F> > {
    401  public:
    402   friend class ArcIterator< FactorWeightFst<A, F> >;
    403   friend class StateIterator< FactorWeightFst<A, F> >;
    404 
    405   typedef A Arc;
    406   typedef typename A::Weight Weight;
    407   typedef typename A::StateId StateId;
    408   typedef CacheState<A> State;
    409   typedef FactorWeightFstImpl<A, F> Impl;
    410 
    411   FactorWeightFst(const Fst<A> &fst)
    412       : ImplToFst<Impl>(new Impl(fst, FactorWeightOptions<A>())) {}
    413 
    414   FactorWeightFst(const Fst<A> &fst,  const FactorWeightOptions<A> &opts)
    415       : ImplToFst<Impl>(new Impl(fst, opts)) {}
    416 
    417   // See Fst<>::Copy() for doc.
    418   FactorWeightFst(const FactorWeightFst<A, F> &fst, bool copy)
    419       : ImplToFst<Impl>(fst, copy) {}
    420 
    421   // Get a copy of this FactorWeightFst. See Fst<>::Copy() for further doc.
    422   virtual FactorWeightFst<A, F> *Copy(bool copy = false) const {
    423     return new FactorWeightFst<A, F>(*this, copy);
    424   }
    425 
    426   virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
    427 
    428   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    429     GetImpl()->InitArcIterator(s, data);
    430   }
    431 
    432  private:
    433   // Makes visible to friends.
    434   Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
    435 
    436   void operator=(const FactorWeightFst<A, F> &fst);  // Disallow
    437 };
    438 
    439 
    440 // Specialization for FactorWeightFst.
    441 template<class A, class F>
    442 class StateIterator< FactorWeightFst<A, F> >
    443     : public CacheStateIterator< FactorWeightFst<A, F> > {
    444  public:
    445   explicit StateIterator(const FactorWeightFst<A, F> &fst)
    446       : CacheStateIterator< FactorWeightFst<A, F> >(fst, fst.GetImpl()) {}
    447 };
    448 
    449 
    450 // Specialization for FactorWeightFst.
    451 template <class A, class F>
    452 class ArcIterator< FactorWeightFst<A, F> >
    453     : public CacheArcIterator< FactorWeightFst<A, F> > {
    454  public:
    455   typedef typename A::StateId StateId;
    456 
    457   ArcIterator(const FactorWeightFst<A, F> &fst, StateId s)
    458       : CacheArcIterator< FactorWeightFst<A, F> >(fst.GetImpl(), s) {
    459     if (!fst.GetImpl()->HasArcs(s))
    460       fst.GetImpl()->Expand(s);
    461   }
    462 
    463  private:
    464   DISALLOW_COPY_AND_ASSIGN(ArcIterator);
    465 };
    466 
    467 template <class A, class F> inline
    468 void FactorWeightFst<A, F>::InitStateIterator(StateIteratorData<A> *data) const
    469 {
    470   data->base = new StateIterator< FactorWeightFst<A, F> >(*this);
    471 }
    472 
    473 
    474 }  // namespace fst
    475 
    476 #endif // FST_LIB_FACTOR_WEIGHT_H__
    477