Home | History | Annotate | Download | only in fst
// state-map.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // \file
     19 // Class to map over/transform states e.g., sort transitions
     20 // Consider using when operation does not change the number of states.
     21 
     22 #ifndef FST_LIB_STATE_MAP_H__
     23 #define FST_LIB_STATE_MAP_H__
     24 
#include <algorithm>
#include <tr1/unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
#include <string>
#include <utility>
using std::pair; using std::make_pair;
#include <vector>

#include <fst/cache.h>
#include <fst/arc-map.h>
#include <fst/mutable-fst.h>
     36 
     37 
     38 namespace fst {
     39 
// StateMapper Interface - class determines how states are mapped.
     41 // Useful for implementing operations that do not change the number of states.
     42 //
     43 // class StateMapper {
     44 //  public:
     45 //   typedef A FromArc;
     46 //   typedef B ToArc;
     47 //
     48 //   // Typical constructor
     49 //   StateMapper(const Fst<A> &fst);
     50 //   // Required copy constructor that allows updating Fst argument;
     51 //   // pass only if relevant and changed.
     52 //   StateMapper(const StateMapper &mapper, const Fst<A> *fst = 0);
     53 //
     54 //   // Specifies initial state of result
     55 //   B::StateId Start() const;
     56 //   // Specifies state's final weight in result
     57 //   B::Weight Final(B::StateId s) const;
     58 //
     59 //   // These methods iterate through a state's arcs in result
     60 //   // Specifies state to iterate over
     61 //   void SetState(B::StateId s);
     62 //   // End of arcs?
     63 //   bool Done() const;
     64 //   // Current arc
     65 
     66 //   const B &Value() const;
     67 //   // Advance to next arc (when !Done)
     68 //   void Next();
     69 //
     70 //   // Specifies input symbol table action the mapper requires (see above).
     71 //   MapSymbolsAction InputSymbolsAction() const;
     72 //   // Specifies output symbol table action the mapper requires (see above).
     73 //   MapSymbolsAction OutputSymbolsAction() const;
     74 //   // This specifies the known properties of an Fst mapped by this
     75 //   // mapper. It takes as argument the input Fst's known properties.
     76 //   uint64 Properties(uint64 props) const;
     77 // };
     78 //
// We include various state map versions below. One dimension of
     80 // variation is whether the mapping mutates its input, writes to a
     81 // new result Fst, or is an on-the-fly Fst. Another dimension is how
     82 // we pass the mapper. We allow passing the mapper by pointer
     83 // for cases that we need to change the state of the user's mapper.
     84 // We also include map versions that pass the mapper
     85 // by value or const reference when this suffices.
     86 
     87 // Maps an arc type A using a mapper function object C, passed
     88 // by pointer.  This version modifies its Fst input.
     89 template<class A, class C>
     90 void StateMap(MutableFst<A> *fst, C* mapper) {
     91   typedef typename A::StateId StateId;
     92   typedef typename A::Weight Weight;
     93 
     94   if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS)
     95     fst->SetInputSymbols(0);
     96 
     97   if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS)
     98     fst->SetOutputSymbols(0);
     99 
    100   if (fst->Start() == kNoStateId)
    101     return;
    102 
    103   uint64 props = fst->Properties(kFstProperties, false);
    104 
    105   fst->SetStart(mapper->Start());
    106 
    107   for (StateId s = 0; s < fst->NumStates(); ++s) {
    108     mapper->SetState(s);
    109     fst->DeleteArcs(s);
    110     for (; !mapper->Done(); mapper->Next())
    111       fst->AddArc(s, mapper->Value());
    112     fst->SetFinal(s, mapper->Final(s));
    113   }
    114 
    115   fst->SetProperties(mapper->Properties(props), kFstProperties);
    116 }
    117 
    118 // Maps an arc type A using a mapper function object C, passed
    119 // by value.  This version modifies its Fst input.
    120 template<class A, class C>
    121 void StateMap(MutableFst<A> *fst, C mapper) {
    122   StateMap(fst, &mapper);
    123 }
    124 
    125 
    126 // Maps an arc type A to an arc type B using mapper function
    127 // object C, passed by pointer. This version writes the mapped
    128 // input Fst to an output MutableFst.
    129 template<class A, class B, class C>
    130 void StateMap(const Fst<A> &ifst, MutableFst<B> *ofst, C* mapper) {
    131   typedef typename A::StateId StateId;
    132   typedef typename A::Weight Weight;
    133 
    134   ofst->DeleteStates();
    135 
    136   if (mapper->InputSymbolsAction() == MAP_COPY_SYMBOLS)
    137     ofst->SetInputSymbols(ifst.InputSymbols());
    138   else if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS)
    139     ofst->SetInputSymbols(0);
    140 
    141   if (mapper->OutputSymbolsAction() == MAP_COPY_SYMBOLS)
    142     ofst->SetOutputSymbols(ifst.OutputSymbols());
    143   else if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS)
    144     ofst->SetOutputSymbols(0);
    145 
    146   uint64 iprops = ifst.Properties(kCopyProperties, false);
    147 
    148   if (ifst.Start() == kNoStateId) {
    149     if (iprops & kError) ofst->SetProperties(kError, kError);
    150     return;
    151   }
    152 
    153   // Add all states.
    154   if (ifst.Properties(kExpanded, false))
    155     ofst->ReserveStates(CountStates(ifst));
    156   for (StateIterator< Fst<A> > siter(ifst); !siter.Done(); siter.Next())
    157     ofst->AddState();
    158 
    159   ofst->SetStart(mapper->Start());
    160 
    161   for (StateIterator< Fst<A> > siter(ifst); !siter.Done(); siter.Next()) {
    162     StateId s = siter.Value();
    163     mapper->SetState(s);
    164     for (; !mapper->Done(); mapper->Next())
    165       ofst->AddArc(s, mapper->Value());
    166     ofst->SetFinal(s, mapper->Final(s));
    167   }
    168 
    169   uint64 oprops = ofst->Properties(kFstProperties, false);
    170   ofst->SetProperties(mapper->Properties(iprops) | oprops, kFstProperties);
    171 }
    172 
    173 // Maps an arc type A to an arc type B using mapper function
    174 // object C, passed by value. This version writes the mapped input
    175 // Fst to an output MutableFst.
    176 template<class A, class B, class C>
    177 void StateMap(const Fst<A> &ifst, MutableFst<B> *ofst, C mapper) {
    178   StateMap(ifst, ofst, &mapper);
    179 }
    180 
    181 typedef CacheOptions StateMapFstOptions;
    182 
    183 template <class A, class B, class C> class StateMapFst;
    184 
    185 // Implementation of delayed StateMapFst.
    186 template <class A, class B, class C>
    187 class StateMapFstImpl : public CacheImpl<B> {
    188  public:
    189   using FstImpl<B>::SetType;
    190   using FstImpl<B>::SetProperties;
    191   using FstImpl<B>::SetInputSymbols;
    192   using FstImpl<B>::SetOutputSymbols;
    193 
    194   using CacheImpl<B>::PushArc;
    195   using CacheImpl<B>::HasArcs;
    196   using CacheImpl<B>::HasFinal;
    197   using CacheImpl<B>::HasStart;
    198   using CacheImpl<B>::SetArcs;
    199   using CacheImpl<B>::SetFinal;
    200   using CacheImpl<B>::SetStart;
    201 
    202   friend class StateIterator< StateMapFst<A, B, C> >;
    203 
    204   typedef B Arc;
    205   typedef typename B::Weight Weight;
    206   typedef typename B::StateId StateId;
    207 
    208   StateMapFstImpl(const Fst<A> &fst, const C &mapper,
    209                  const StateMapFstOptions& opts)
    210       : CacheImpl<B>(opts),
    211         fst_(fst.Copy()),
    212         mapper_(new C(mapper, fst_)),
    213         own_mapper_(true) {
    214     Init();
    215   }
    216 
    217   StateMapFstImpl(const Fst<A> &fst, C *mapper,
    218                  const StateMapFstOptions& opts)
    219       : CacheImpl<B>(opts),
    220         fst_(fst.Copy()),
    221         mapper_(mapper),
    222         own_mapper_(false) {
    223     Init();
    224   }
    225 
    226   StateMapFstImpl(const StateMapFstImpl<A, B, C> &impl)
    227       : CacheImpl<B>(impl),
    228         fst_(impl.fst_->Copy(true)),
    229         mapper_(new C(*impl.mapper_, fst_)),
    230         own_mapper_(true) {
    231     Init();
    232   }
    233 
    234   ~StateMapFstImpl() {
    235     delete fst_;
    236     if (own_mapper_) delete mapper_;
    237   }
    238 
    239   StateId Start() {
    240     if (!HasStart())
    241       SetStart(mapper_->Start());
    242     return CacheImpl<B>::Start();
    243   }
    244 
    245   Weight Final(StateId s) {
    246     if (!HasFinal(s))
    247       SetFinal(s, mapper_->Final(s));
    248     return CacheImpl<B>::Final(s);
    249   }
    250 
    251   size_t NumArcs(StateId s) {
    252     if (!HasArcs(s))
    253       Expand(s);
    254     return CacheImpl<B>::NumArcs(s);
    255   }
    256 
    257   size_t NumInputEpsilons(StateId s) {
    258     if (!HasArcs(s))
    259       Expand(s);
    260     return CacheImpl<B>::NumInputEpsilons(s);
    261   }
    262 
    263   size_t NumOutputEpsilons(StateId s) {
    264     if (!HasArcs(s))
    265       Expand(s);
    266     return CacheImpl<B>::NumOutputEpsilons(s);
    267   }
    268 
    269   void InitStateIterator(StateIteratorData<A> *data) const {
    270     fst_->InitStateIterator(data);
    271   }
    272 
    273   void InitArcIterator(StateId s, ArcIteratorData<B> *data) {
    274     if (!HasArcs(s))
    275       Expand(s);
    276     CacheImpl<B>::InitArcIterator(s, data);
    277   }
    278 
    279   uint64 Properties() const { return Properties(kFstProperties); }
    280 
    281   // Set error if found; return FST impl properties.
    282   uint64 Properties(uint64 mask) const {
    283     if ((mask & kError) && (fst_->Properties(kError, false) ||
    284                             (mapper_->Properties(0) & kError)))
    285       SetProperties(kError, kError);
    286     return FstImpl<Arc>::Properties(mask);
    287   }
    288 
    289   void Expand(StateId s) {
    290     // Add exiting arcs.
    291     for (mapper_->SetState(s); !mapper_->Done(); mapper_->Next())
    292       PushArc(s, mapper_->Value());
    293     SetArcs(s);
    294   }
    295 
    296   const Fst<A> &GetFst() const {
    297     return *fst_;
    298   }
    299 
    300  private:
    301   void Init() {
    302     SetType("statemap");
    303 
    304     if (mapper_->InputSymbolsAction() == MAP_COPY_SYMBOLS)
    305       SetInputSymbols(fst_->InputSymbols());
    306     else if (mapper_->InputSymbolsAction() == MAP_CLEAR_SYMBOLS)
    307       SetInputSymbols(0);
    308 
    309     if (mapper_->OutputSymbolsAction() == MAP_COPY_SYMBOLS)
    310       SetOutputSymbols(fst_->OutputSymbols());
    311     else if (mapper_->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS)
    312       SetOutputSymbols(0);
    313 
    314     uint64 props = fst_->Properties(kCopyProperties, false);
    315     SetProperties(mapper_->Properties(props));
    316   }
    317 
    318   const Fst<A> *fst_;
    319   C*  mapper_;
    320   bool own_mapper_;
    321 
    322   void operator=(const StateMapFstImpl<A, B, C> &);  // disallow
    323 };
    324 
    325 
    326 // Maps an arc type A to an arc type B using Mapper function object
    327 // C. This version is a delayed Fst.
    328 template <class A, class B, class C>
    329 class StateMapFst : public ImplToFst< StateMapFstImpl<A, B, C> > {
    330  public:
    331   friend class ArcIterator< StateMapFst<A, B, C> >;
    332 
    333   typedef B Arc;
    334   typedef typename B::Weight Weight;
    335   typedef typename B::StateId StateId;
    336   typedef CacheState<B> State;
    337   typedef StateMapFstImpl<A, B, C> Impl;
    338 
    339   StateMapFst(const Fst<A> &fst, const C &mapper,
    340               const StateMapFstOptions& opts)
    341       : ImplToFst<Impl>(new Impl(fst, mapper, opts)) {}
    342 
    343   StateMapFst(const Fst<A> &fst, C* mapper, const StateMapFstOptions& opts)
    344       : ImplToFst<Impl>(new Impl(fst, mapper, opts)) {}
    345 
    346   StateMapFst(const Fst<A> &fst, const C &mapper)
    347       : ImplToFst<Impl>(new Impl(fst, mapper, StateMapFstOptions())) {}
    348 
    349   StateMapFst(const Fst<A> &fst, C* mapper)
    350       : ImplToFst<Impl>(new Impl(fst, mapper, StateMapFstOptions())) {}
    351 
    352   // See Fst<>::Copy() for doc.
    353   StateMapFst(const StateMapFst<A, B, C> &fst, bool safe = false)
    354     : ImplToFst<Impl>(fst, safe) {}
    355 
    356   // Get a copy of this StateMapFst. See Fst<>::Copy() for further doc.
    357   virtual StateMapFst<A, B, C> *Copy(bool safe = false) const {
    358     return new StateMapFst<A, B, C>(*this, safe);
    359   }
    360 
    361   virtual void InitStateIterator(StateIteratorData<A> *data) const {
    362     GetImpl()->InitStateIterator(data);
    363   }
    364 
    365   virtual void InitArcIterator(StateId s, ArcIteratorData<B> *data) const {
    366     GetImpl()->InitArcIterator(s, data);
    367   }
    368 
    369  protected:
    370   Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
    371 
    372  private:
    373   void operator=(const StateMapFst<A, B, C> &fst);  // disallow
    374 };
    375 
    376 
    377 // Specialization for StateMapFst.
    378 template <class A, class B, class C>
    379 class ArcIterator< StateMapFst<A, B, C> >
    380     : public CacheArcIterator< StateMapFst<A, B, C> > {
    381  public:
    382   typedef typename A::StateId StateId;
    383 
    384   ArcIterator(const StateMapFst<A, B, C> &fst, StateId s)
    385       : CacheArcIterator< StateMapFst<A, B, C> >(fst.GetImpl(), s) {
    386     if (!fst.GetImpl()->HasArcs(s))
    387       fst.GetImpl()->Expand(s);
    388   }
    389 
    390  private:
    391   DISALLOW_COPY_AND_ASSIGN(ArcIterator);
    392 };
    393 
    394 //
    395 // Utility Mappers
    396 //
    397 
    398 // Mapper that returns its input.
    399 template <class A>
    400 class IdentityStateMapper {
    401  public:
    402   typedef A FromArc;
    403   typedef A ToArc;
    404 
    405   typedef typename A::StateId StateId;
    406   typedef typename A::Weight Weight;
    407 
    408   explicit IdentityStateMapper(const Fst<A> &fst) : fst_(fst), aiter_(0) {}
    409 
    410   // Allows updating Fst argument; pass only if changed.
    411   IdentityStateMapper(const IdentityStateMapper<A> &mapper,
    412                       const Fst<A> *fst = 0)
    413       : fst_(fst ? *fst : mapper.fst_), aiter_(0) {}
    414 
    415   ~IdentityStateMapper() { delete aiter_; }
    416 
    417   StateId Start() const { return fst_.Start(); }
    418 
    419   Weight Final(StateId s) const { return fst_.Final(s); }
    420 
    421   void SetState(StateId s) {
    422     if (aiter_) delete aiter_;
    423     aiter_ = new ArcIterator< Fst<A> >(fst_, s);
    424   }
    425 
    426   bool Done() const { return aiter_->Done(); }
    427   const A &Value() const { return aiter_->Value(); }
    428   void Next() { aiter_->Next(); }
    429 
    430   MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
    431   MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS;}
    432 
    433   uint64 Properties(uint64 props) const { return props; }
    434 
    435  private:
    436   const Fst<A> &fst_;
    437   ArcIterator< Fst<A> > *aiter_;
    438 };
    439 
    440 template <class A>
    441 class ArcSumMapper {
    442  public:
    443   typedef A FromArc;
    444   typedef A ToArc;
    445 
    446   typedef typename A::StateId StateId;
    447   typedef typename A::Weight Weight;
    448 
    449   explicit ArcSumMapper(const Fst<A> &fst) : fst_(fst), i_(0) {}
    450 
    451   // Allows updating Fst argument; pass only if changed.
    452   ArcSumMapper(const ArcSumMapper<A> &mapper,
    453                const Fst<A> *fst = 0)
    454       : fst_(fst ? *fst : mapper.fst_), i_(0) {}
    455 
    456   StateId Start() const { return fst_.Start(); }
    457   Weight Final(StateId s) const { return fst_.Final(s); }
    458 
    459   void SetState(StateId s) {
    460     i_ = 0;
    461     arcs_.clear();
    462     arcs_.reserve(fst_.NumArcs(s));
    463     for (ArcIterator<Fst<A> > aiter(fst_, s); !aiter.Done(); aiter.Next())
    464       arcs_.push_back(aiter.Value());
    465 
    466     // First sorts the exiting arcs by input label, output label
    467     // and destination state and then sums weights of arcs with
    468     // the same input label, output label, and destination state.
    469     sort(arcs_.begin(), arcs_.end(), comp_);
    470     size_t narcs = 0;
    471     for (size_t i = 0; i < arcs_.size(); ++i) {
    472       if (narcs > 0 && equal_(arcs_[i], arcs_[narcs - 1])) {
    473         arcs_[narcs - 1].weight = Plus(arcs_[narcs - 1].weight,
    474                                        arcs_[i].weight);
    475       } else {
    476         arcs_[narcs++] = arcs_[i];
    477       }
    478     }
    479     arcs_.resize(narcs);
    480   }
    481 
    482   bool Done() const { return i_ >= arcs_.size(); }
    483   const A &Value() const { return arcs_[i_]; }
    484   void Next() { ++i_; }
    485 
    486   MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
    487   MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
    488 
    489   uint64 Properties(uint64 props) const {
    490     return props & kArcSortProperties &
    491         kDeleteArcsProperties & kWeightInvariantProperties;
    492   }
    493 
    494  private:
    495   struct Compare {
    496     bool operator()(const A& x, const A& y) {
    497       if (x.ilabel < y.ilabel) return true;
    498       if (x.ilabel > y.ilabel) return false;
    499       if (x.olabel < y.olabel) return true;
    500       if (x.olabel > y.olabel) return false;
    501       if (x.nextstate < y.nextstate) return true;
    502       if (x.nextstate > y.nextstate) return false;
    503       return false;
    504     }
    505   };
    506 
    507   struct Equal {
    508     bool operator()(const A& x, const A& y) {
    509       return (x.ilabel == y.ilabel &&
    510               x.olabel == y.olabel &&
    511               x.nextstate == y.nextstate);
    512     }
    513   };
    514 
    515   const Fst<A> &fst_;
    516   Compare comp_;
    517   Equal equal_;
    518   vector<A> arcs_;
    519   ssize_t i_;               // current arc position
    520 
    521   void operator=(const ArcSumMapper<A> &);  // disallow
    522 };
    523 
    524 template <class A>
    525 class ArcUniqueMapper {
    526  public:
    527   typedef A FromArc;
    528   typedef A ToArc;
    529 
    530   typedef typename A::StateId StateId;
    531   typedef typename A::Weight Weight;
    532 
    533   explicit ArcUniqueMapper(const Fst<A> &fst) : fst_(fst), i_(0) {}
    534 
    535   // Allows updating Fst argument; pass only if changed.
    536   ArcUniqueMapper(const ArcUniqueMapper<A> &mapper,
    537                   const Fst<A> *fst = 0)
    538       : fst_(fst ? *fst : mapper.fst_), i_(0) {}
    539 
    540   StateId Start() const { return fst_.Start(); }
    541   Weight Final(StateId s) const { return fst_.Final(s); }
    542 
    543   void SetState(StateId s) {
    544     i_ = 0;
    545     arcs_.clear();
    546     arcs_.reserve(fst_.NumArcs(s));
    547     for (ArcIterator<Fst<A> > aiter(fst_, s); !aiter.Done(); aiter.Next())
    548       arcs_.push_back(aiter.Value());
    549 
    550     // First sorts the exiting arcs by input label, output label
    551     // and destination state and then uniques identical arcs
    552     sort(arcs_.begin(), arcs_.end(), comp_);
    553     typename vector<A>::iterator unique_end =
    554         unique(arcs_.begin(), arcs_.end(), equal_);
    555     arcs_.resize(unique_end - arcs_.begin());
    556   }
    557 
    558   bool Done() const { return i_ >= arcs_.size(); }
    559   const A &Value() const { return arcs_[i_]; }
    560   void Next() { ++i_; }
    561 
    562   MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
    563   MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
    564 
    565   uint64 Properties(uint64 props) const {
    566     return props & kArcSortProperties & kDeleteArcsProperties;
    567   }
    568 
    569  private:
    570   struct Compare {
    571     bool operator()(const A& x, const A& y) {
    572       if (x.ilabel < y.ilabel) return true;
    573       if (x.ilabel > y.ilabel) return false;
    574       if (x.olabel < y.olabel) return true;
    575       if (x.olabel > y.olabel) return false;
    576       if (x.nextstate < y.nextstate) return true;
    577       if (x.nextstate > y.nextstate) return false;
    578       return false;
    579     }
    580   };
    581 
    582   struct Equal {
    583     bool operator()(const A& x, const A& y) {
    584       return (x.ilabel == y.ilabel &&
    585               x.olabel == y.olabel &&
    586               x.nextstate == y.nextstate &&
    587               x.weight == y.weight);
    588     }
    589   };
    590 
    591   const Fst<A> &fst_;
    592   Compare comp_;
    593   Equal equal_;
    594   vector<A> arcs_;
    595   ssize_t i_;               // current arc position
    596 
    597   void operator=(const ArcUniqueMapper<A> &);  // disallow
    598 };
    599 
    600 
    601 }  // namespace fst
    602 
    603 #endif  // FST_LIB_STATE_MAP_H__
    604