Home | History | Annotate | Download | only in fst
      1 // synchronize.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: allauzen (at) google.com (Cyril Allauzen)
     17 //
     18 // \file
     19 // Synchronize an FST with bounded delay.
     20 
     21 #ifndef FST_LIB_SYNCHRONIZE_H__
     22 #define FST_LIB_SYNCHRONIZE_H__
     23 
     24 #include <algorithm>
     25 #include <tr1/unordered_map>
     26 using std::tr1::unordered_map;
     27 using std::tr1::unordered_multimap;
     28 #include <tr1/unordered_set>
     29 using std::tr1::unordered_set;
     30 using std::tr1::unordered_multiset;
     31 #include <string>
     32 #include <utility>
     33 using std::pair; using std::make_pair;
     34 #include <vector>
     35 using std::vector;
     36 
     37 #include <fst/cache.h>
     38 #include <fst/test-properties.h>
     39 
     40 
     41 namespace fst {
     42 
// SynchronizeFst defines no options of its own: its options type is an
// alias for the generic cache options.
typedef CacheOptions SynchronizeFstOptions;
     44 
     45 
     46 // Implementation class for SynchronizeFst
     47 template <class A>
     48 class SynchronizeFstImpl
     49     : public CacheImpl<A> {
     50  public:
     51   using FstImpl<A>::SetType;
     52   using FstImpl<A>::SetProperties;
     53   using FstImpl<A>::SetInputSymbols;
     54   using FstImpl<A>::SetOutputSymbols;
     55 
     56   using CacheBaseImpl< CacheState<A> >::PushArc;
     57   using CacheBaseImpl< CacheState<A> >::HasArcs;
     58   using CacheBaseImpl< CacheState<A> >::HasFinal;
     59   using CacheBaseImpl< CacheState<A> >::HasStart;
     60   using CacheBaseImpl< CacheState<A> >::SetArcs;
     61   using CacheBaseImpl< CacheState<A> >::SetFinal;
     62   using CacheBaseImpl< CacheState<A> >::SetStart;
     63 
     64   typedef A Arc;
     65   typedef typename A::Label Label;
     66   typedef typename A::Weight Weight;
     67   typedef typename A::StateId StateId;
     68 
     69   typedef basic_string<Label> String;
     70 
     71   struct Element {
     72     Element() {}
     73 
     74     Element(StateId s, const String *i, const String *o)
     75         : state(s), istring(i), ostring(o) {}
     76 
     77     StateId state;     // Input state Id
     78     const String *istring;     // Residual input labels
     79     const String *ostring;     // Residual output labels
     80     // Residual strings are represented by const pointers to
     81     // basic_string<Label> and are stored in a hash_set. The pointed
     82     // memory is owned by the hash_set string_set_.
     83   };
     84 
     85   SynchronizeFstImpl(const Fst<A> &fst, const SynchronizeFstOptions &opts)
     86       : CacheImpl<A>(opts), fst_(fst.Copy()) {
     87     SetType("synchronize");
     88     uint64 props = fst.Properties(kFstProperties, false);
     89     SetProperties(SynchronizeProperties(props), kCopyProperties);
     90 
     91     SetInputSymbols(fst.InputSymbols());
     92     SetOutputSymbols(fst.OutputSymbols());
     93   }
     94 
     95   SynchronizeFstImpl(const SynchronizeFstImpl &impl)
     96       : CacheImpl<A>(impl),
     97         fst_(impl.fst_->Copy(true)) {
     98     SetType("synchronize");
     99     SetProperties(impl.Properties(), kCopyProperties);
    100     SetInputSymbols(impl.InputSymbols());
    101     SetOutputSymbols(impl.OutputSymbols());
    102   }
    103 
    104   ~SynchronizeFstImpl() {
    105     delete fst_;
    106     // Extract pointers from the hash set
    107     vector<const String*> strings;
    108     typename StringSet::iterator it = string_set_.begin();
    109     for (; it != string_set_.end(); ++it)
    110       strings.push_back(*it);
    111     // Free the extracted pointers
    112     for (size_t i = 0; i < strings.size(); ++i)
    113       delete strings[i];
    114   }
    115 
    116   StateId Start() {
    117     if (!HasStart()) {
    118       StateId s = fst_->Start();
    119       if (s == kNoStateId)
    120         return kNoStateId;
    121       const String *empty = FindString(new String());
    122       StateId start = FindState(Element(fst_->Start(), empty, empty));
    123       SetStart(start);
    124     }
    125     return CacheImpl<A>::Start();
    126   }
    127 
    128   Weight Final(StateId s) {
    129     if (!HasFinal(s)) {
    130       const Element &e = elements_[s];
    131       Weight w = e.state == kNoStateId ? Weight::One() : fst_->Final(e.state);
    132       if ((w != Weight::Zero()) && (e.istring)->empty() && (e.ostring)->empty())
    133         SetFinal(s, w);
    134       else
    135         SetFinal(s, Weight::Zero());
    136     }
    137     return CacheImpl<A>::Final(s);
    138   }
    139 
    140   size_t NumArcs(StateId s) {
    141     if (!HasArcs(s))
    142       Expand(s);
    143     return CacheImpl<A>::NumArcs(s);
    144   }
    145 
    146   size_t NumInputEpsilons(StateId s) {
    147     if (!HasArcs(s))
    148       Expand(s);
    149     return CacheImpl<A>::NumInputEpsilons(s);
    150   }
    151 
    152   size_t NumOutputEpsilons(StateId s) {
    153     if (!HasArcs(s))
    154       Expand(s);
    155     return CacheImpl<A>::NumOutputEpsilons(s);
    156   }
    157 
    158   uint64 Properties() const { return Properties(kFstProperties); }
    159 
    160   // Set error if found; return FST impl properties.
    161   uint64 Properties(uint64 mask) const {
    162     if ((mask & kError) && fst_->Properties(kError, false))
    163       SetProperties(kError, kError);
    164     return FstImpl<Arc>::Properties(mask);
    165   }
    166 
    167   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
    168     if (!HasArcs(s))
    169       Expand(s);
    170     CacheImpl<A>::InitArcIterator(s, data);
    171   }
    172 
    173   // Returns the first character of the string obtained by
    174   // concatenating s and l.
    175   Label Car(const String *s, Label l = 0) const {
    176     if (!s->empty())
    177       return (*s)[0];
    178     else
    179       return l;
    180   }
    181 
    182   // Computes the residual string obtained by removing the first
    183   // character in the concatenation of s and l.
    184   const String *Cdr(const String *s, Label l = 0) {
    185     String *r = new String();
    186     for (int i = 1; i < s->size(); ++i)
    187       r->push_back((*s)[i]);
    188     if (l && !(s->empty())) r->push_back(l);
    189     return FindString(r);
    190   }
    191 
    192   // Computes the concatenation of s and l.
    193   const String *Concat(const String *s, Label l = 0) {
    194     String *r = new String();
    195     for (int i = 0; i < s->size(); ++i)
    196       r->push_back((*s)[i]);
    197     if (l) r->push_back(l);
    198     return FindString(r);
    199   }
    200 
    201   // Tests if the concatenation of s and l is empty
    202   bool Empty(const String *s, Label l = 0) const {
    203     if (s->empty())
    204       return l == 0;
    205     else
    206       return false;
    207   }
    208 
    209   // Finds the string pointed by s in the hash set. Transfers the
    210   // pointer ownership to the hash set.
    211   const String *FindString(const String *s) {
    212     typename StringSet::iterator it = string_set_.find(s);
    213     if (it != string_set_.end()) {
    214       delete s;
    215       return (*it);
    216     } else {
    217       string_set_.insert(s);
    218       return s;
    219     }
    220   }
    221 
    222   // Finds state corresponding to an element. Creates new state
    223   // if element not found.
    224   StateId FindState(const Element &e) {
    225     typename ElementMap::iterator eit = element_map_.find(e);
    226     if (eit != element_map_.end()) {
    227       return (*eit).second;
    228     } else {
    229       StateId s = elements_.size();
    230       elements_.push_back(e);
    231       element_map_.insert(pair<const Element, StateId>(e, s));
    232       return s;
    233     }
    234   }
    235 
    236 
    237   // Computes the outgoing transitions from a state, creating new destination
    238   // states as needed.
    239   void Expand(StateId s) {
    240     Element e = elements_[s];
    241 
    242     if (e.state != kNoStateId)
    243       for (ArcIterator< Fst<A> > ait(*fst_, e.state);
    244            !ait.Done();
    245            ait.Next()) {
    246         const A &arc = ait.Value();
    247         if (!Empty(e.istring, arc.ilabel)  && !Empty(e.ostring, arc.olabel)) {
    248           const String *istring = Cdr(e.istring, arc.ilabel);
    249           const String *ostring = Cdr(e.ostring, arc.olabel);
    250           StateId d = FindState(Element(arc.nextstate, istring, ostring));
    251           PushArc(s, Arc(Car(e.istring, arc.ilabel),
    252                         Car(e.ostring, arc.olabel), arc.weight, d));
    253         } else {
    254           const String *istring = Concat(e.istring, arc.ilabel);
    255           const String *ostring = Concat(e.ostring, arc.olabel);
    256           StateId d = FindState(Element(arc.nextstate, istring, ostring));
    257           PushArc(s, Arc(0 , 0, arc.weight, d));
    258         }
    259       }
    260 
    261     Weight w = e.state == kNoStateId ? Weight::One() : fst_->Final(e.state);
    262     if ((w != Weight::Zero()) &&
    263         ((e.istring)->size() + (e.ostring)->size() > 0)) {
    264       const String *istring = Cdr(e.istring);
    265       const String *ostring = Cdr(e.ostring);
    266       StateId d = FindState(Element(kNoStateId, istring, ostring));
    267       PushArc(s, Arc(Car(e.istring), Car(e.ostring), w, d));
    268     }
    269     SetArcs(s);
    270   }
    271 
    272  private:
    273   // Equality function for Elements, assume strings have been hashed.
    274   class ElementEqual {
    275    public:
    276     bool operator()(const Element &x, const Element &y) const {
    277       return x.state == y.state &&
    278               x.istring == y.istring &&
    279               x.ostring == y.ostring;
    280     }
    281   };
    282 
    283   // Hash function for Elements to Fst states.
    284   class ElementKey {
    285    public:
    286     size_t operator()(const Element &x) const {
    287       size_t key = x.state;
    288       key = (key << 1) ^ (x.istring)->size();
    289       for (size_t i = 0; i < (x.istring)->size(); ++i)
    290         key = (key << 1) ^ (*x.istring)[i];
    291       key = (key << 1) ^ (x.ostring)->size();
    292       for (size_t i = 0; i < (x.ostring)->size(); ++i)
    293         key = (key << 1) ^ (*x.ostring)[i];
    294       return key;
    295     }
    296   };
    297 
    298   // Equality function for strings
    299   class StringEqual {
    300    public:
    301     bool operator()(const String * const &x, const String * const &y) const {
    302       if (x->size() != y->size()) return false;
    303       for (size_t i = 0; i < x->size(); ++i)
    304         if ((*x)[i] != (*y)[i]) return false;
    305       return true;
    306     }
    307   };
    308 
    309   // Hash function for set of strings
    310   class StringKey{
    311    public:
    312     size_t operator()(const String * const & x) const {
    313       size_t key = x->size();
    314       for (size_t i = 0; i < x->size(); ++i)
    315         key = (key << 1) ^ (*x)[i];
    316       return key;
    317     }
    318   };
    319 
    320 
    321   typedef unordered_map<Element, StateId, ElementKey, ElementEqual> ElementMap;
    322   typedef unordered_set<const String*, StringKey, StringEqual> StringSet;
    323 
    324   const Fst<A> *fst_;
    325   vector<Element> elements_;  // mapping Fst state to Elements
    326   ElementMap element_map_;    // mapping Elements to Fst state
    327   StringSet string_set_;
    328 
    329   void operator=(const SynchronizeFstImpl<A> &);  // disallow
    330 };
    331 
    332 
    333 // Synchronizes a transducer. This version is a delayed Fst.  The
    334 // result will be an equivalent FST that has the property that during
    335 // the traversal of a path, the delay is either zero or strictly
    336 // increasing, where the delay is the difference between the number of
    337 // non-epsilon output labels and input labels along the path.
    338 //
    339 // For the algorithm to terminate, the input transducer must have
    340 // bounded delay, i.e., the delay of every cycle must be zero.
    341 //
    342 // Complexity:
    343 // - A has bounded delay: exponential
    344 // - A does not have bounded delay: does not terminate
    345 //
    346 // References:
    347 // - Mehryar Mohri. Edit-Distance of Weighted Automata: General
    348 //   Definitions and Algorithms, International Journal of Computer
    349 //   Science, 14(6): 957-982 (2003).
    350 //
    351 // This class attaches interface to implementation and handles
    352 // reference counting, delegating most methods to ImplToFst.
    353 template <class A>
    354 class SynchronizeFst : public ImplToFst< SynchronizeFstImpl<A> > {
    355  public:
    356   friend class ArcIterator< SynchronizeFst<A> >;
    357   friend class StateIterator< SynchronizeFst<A> >;
    358 
    359   typedef A Arc;
    360   typedef typename A::Weight Weight;
    361   typedef typename A::StateId StateId;
    362   typedef CacheState<A> State;
    363   typedef SynchronizeFstImpl<A> Impl;
    364 
    365   SynchronizeFst(const Fst<A> &fst)
    366       : ImplToFst<Impl>(new Impl(fst, SynchronizeFstOptions())) {}
    367 
    368   SynchronizeFst(const Fst<A> &fst,  const SynchronizeFstOptions &opts)
    369       : ImplToFst<Impl>(new Impl(fst, opts)) {}
    370 
    371   // See Fst<>::Copy() for doc.
    372   SynchronizeFst(const SynchronizeFst<A> &fst, bool safe = false)
    373       : ImplToFst<Impl>(fst, safe) {}
    374 
    375   // Get a copy of this SynchronizeFst. See Fst<>::Copy() for further doc.
    376   virtual SynchronizeFst<A> *Copy(bool safe = false) const {
    377     return new SynchronizeFst<A>(*this, safe);
    378   }
    379 
    380   virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
    381 
    382   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    383     GetImpl()->InitArcIterator(s, data);
    384   }
    385 
    386  private:
    387   // Makes visible to friends.
    388   Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
    389 
    390   void operator=(const SynchronizeFst<A> &fst);  // Disallow
    391 };
    392 
    393 
    394 // Specialization for SynchronizeFst.
    395 template<class A>
    396 class StateIterator< SynchronizeFst<A> >
    397     : public CacheStateIterator< SynchronizeFst<A> > {
    398  public:
    399   explicit StateIterator(const SynchronizeFst<A> &fst)
    400       : CacheStateIterator< SynchronizeFst<A> >(fst, fst.GetImpl()) {}
    401 };
    402 
    403 
    404 // Specialization for SynchronizeFst.
    405 template <class A>
    406 class ArcIterator< SynchronizeFst<A> >
    407     : public CacheArcIterator< SynchronizeFst<A> > {
    408  public:
    409   typedef typename A::StateId StateId;
    410 
    411   ArcIterator(const SynchronizeFst<A> &fst, StateId s)
    412       : CacheArcIterator< SynchronizeFst<A> >(fst.GetImpl(), s) {
    413     if (!fst.GetImpl()->HasArcs(s))
    414       fst.GetImpl()->Expand(s);
    415   }
    416 
    417  private:
    418   DISALLOW_COPY_AND_ASSIGN(ArcIterator);
    419 };
    420 
    421 
    422 template <class A> inline
    423 void SynchronizeFst<A>::InitStateIterator(StateIteratorData<A> *data) const
    424 {
    425   data->base = new StateIterator< SynchronizeFst<A> >(*this);
    426 }
    427 
    428 
    429 
    430 // Synchronizes a transducer. This version writes the synchronized
    431 // result to a MutableFst.  The result will be an equivalent FST that
    432 // has the property that during the traversal of a path, the delay is
    433 // either zero or strictly increasing, where the delay is the
    434 // difference between the number of non-epsilon output labels and
    435 // input labels along the path.
    436 //
    437 // For the algorithm to terminate, the input transducer must have
    438 // bounded delay, i.e., the delay of every cycle must be zero.
    439 //
    440 // Complexity:
    441 // - A has bounded delay: exponential
    442 // - A does not have bounded delay: does not terminate
    443 //
    444 // References:
    445 // - Mehryar Mohri. Edit-Distance of Weighted Automata: General
    446 //   Definitions and Algorithms, International Journal of Computer
    447 //   Science, 14(6): 957-982 (2003).
    448 template<class Arc>
    449 void Synchronize(const Fst<Arc> &ifst, MutableFst<Arc> *ofst) {
    450   SynchronizeFstOptions opts;
    451   opts.gc_limit = 0;  // Cache only the last state for fastest copy.
    452   *ofst = SynchronizeFst<Arc>(ifst, opts);
    453 }
    454 
    455 }  // namespace fst
    456 
    457 #endif // FST_LIB_SYNCHRONIZE_H__
    458