Home | History | Annotate | Download | only in lib
      1 // synchronize.h
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Author: allauzen (at) cs.nyu.edu (Cyril Allauzen)
     16 //
     17 // \file
     18 // Synchronize an FST with bounded delay.
     19 
     20 #ifndef FST_LIB_SYNCHRONIZE_H__
     21 #define FST_LIB_SYNCHRONIZE_H__
     22 
#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "fst/lib/cache.h"
#include "fst/lib/test-properties.h"
     30 
     31 namespace fst {
     32 
     33 typedef CacheOptions SynchronizeFstOptions;
     34 
     35 
     36 // Implementation class for SynchronizeFst
     37 template <class A>
     38 class SynchronizeFstImpl
     39     : public CacheImpl<A> {
     40  public:
     41   using FstImpl<A>::SetType;
     42   using FstImpl<A>::SetProperties;
     43   using FstImpl<A>::Properties;
     44   using FstImpl<A>::SetInputSymbols;
     45   using FstImpl<A>::SetOutputSymbols;
     46 
     47   using CacheBaseImpl< CacheState<A> >::HasStart;
     48   using CacheBaseImpl< CacheState<A> >::HasFinal;
     49   using CacheBaseImpl< CacheState<A> >::HasArcs;
     50 
     51   typedef A Arc;
     52   typedef typename A::Label Label;
     53   typedef typename A::Weight Weight;
     54   typedef typename A::StateId StateId;
     55 
     56   typedef basic_string<Label> String;
     57 
     58   struct Element {
     59     Element() {}
     60 
     61     Element(StateId s, const String *i, const String *o)
     62         : state(s), istring(i), ostring(o) {}
     63 
     64     StateId state;     // Input state Id
     65     const String *istring;     // Residual input labels
     66     const String *ostring;     // Residual output labels
     67     // Residual strings are represented by const pointers to
     68     // basic_string<Label> and are stored in a hash_set. The pointed
     69     // memory is owned by the hash_set string_set_.
     70   };
     71 
     72   SynchronizeFstImpl(const Fst<A> &fst, const SynchronizeFstOptions &opts)
     73       : CacheImpl<A>(opts), fst_(fst.Copy()) {
     74     SetType("synchronize");
     75     uint64 props = fst.Properties(kFstProperties, false);
     76     SetProperties(SynchronizeProperties(props), kCopyProperties);
     77 
     78     SetInputSymbols(fst.InputSymbols());
     79     SetOutputSymbols(fst.OutputSymbols());
     80   }
     81 
     82   ~SynchronizeFstImpl() {
     83     delete fst_;
     84     // Extract pointers from the hash set
     85     vector<const String*> strings;
     86     typename StringSet::iterator it = string_set_.begin();
     87     for (; it != string_set_.end(); ++it)
     88       strings.push_back(*it);
     89     // Free the extracted pointers
     90     for (size_t i = 0; i < strings.size(); ++i)
     91       delete strings[i];
     92   }
     93 
     94   StateId Start() {
     95     if (!HasStart()) {
     96       StateId s = fst_->Start();
     97       if (s == kNoStateId)
     98         return kNoStateId;
     99       const String *empty = FindString(new String());
    100       StateId start = FindState(Element(fst_->Start(), empty, empty));
    101       SetStart(start);
    102     }
    103     return CacheImpl<A>::Start();
    104   }
    105 
    106   Weight Final(StateId s) {
    107     if (!HasFinal(s)) {
    108       const Element &e = elements_[s];
    109       Weight w = e.state == kNoStateId ? Weight::One() : fst_->Final(e.state);
    110       if ((w != Weight::Zero()) && (e.istring)->empty() && (e.ostring)->empty())
    111         SetFinal(s, w);
    112       else
    113         SetFinal(s, Weight::Zero());
    114     }
    115     return CacheImpl<A>::Final(s);
    116   }
    117 
    118   size_t NumArcs(StateId s) {
    119     if (!HasArcs(s))
    120       Expand(s);
    121     return CacheImpl<A>::NumArcs(s);
    122   }
    123 
    124   size_t NumInputEpsilons(StateId s) {
    125     if (!HasArcs(s))
    126       Expand(s);
    127     return CacheImpl<A>::NumInputEpsilons(s);
    128   }
    129 
    130   size_t NumOutputEpsilons(StateId s) {
    131     if (!HasArcs(s))
    132       Expand(s);
    133     return CacheImpl<A>::NumOutputEpsilons(s);
    134   }
    135 
    136   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
    137     if (!HasArcs(s))
    138       Expand(s);
    139     CacheImpl<A>::InitArcIterator(s, data);
    140   }
    141 
    142   // Returns the first character of the string obtained by
    143   // concatenating s and l.
    144   Label Car(const String *s, Label l = 0) const {
    145     if (!s->empty())
    146       return (*s)[0];
    147     else
    148       return l;
    149   }
    150 
    151   // Computes the residual string obtained by removing the first
    152   // character in the concatenation of s and l.
    153   const String *Cdr(const String *s, Label l = 0) {
    154     String *r = new String();
    155     for (int i = 1; i < s->size(); ++i)
    156       r->push_back((*s)[i]);
    157     if (l && !(s->empty())) r->push_back(l);
    158     return FindString(r);
    159   }
    160 
    161   // Computes the concatenation of s and l.
    162   const String *Concat(const String *s, Label l = 0) {
    163     String *r = new String();
    164     for (int i = 0; i < s->size(); ++i)
    165       r->push_back((*s)[i]);
    166     if (l) r->push_back(l);
    167     return FindString(r);
    168   }
    169 
    170   // Tests if the concatenation of s and l is empty
    171   bool Empty(const String *s, Label l = 0) const {
    172     if (s->empty())
    173       return l == 0;
    174     else
    175       return false;
    176   }
    177 
    178   // Finds the string pointed by s in the hash set. Transfers the
    179   // pointer ownership to the hash set.
    180   const String *FindString(const String *s) {
    181     typename StringSet::iterator it = string_set_.find(s);
    182     if (it != string_set_.end()) {
    183       delete s;
    184       return (*it);
    185     } else {
    186       string_set_.insert(s);
    187       return s;
    188     }
    189   }
    190 
    191   // Finds state corresponding to an element. Creates new state
    192   // if element not found.
    193   StateId FindState(const Element &e) {
    194     typename ElementMap::iterator eit = element_map_.find(e);
    195     if (eit != element_map_.end()) {
    196       return (*eit).second;
    197     } else {
    198       StateId s = elements_.size();
    199       elements_.push_back(e);
    200       element_map_.insert(pair<const Element, StateId>(e, s));
    201       return s;
    202     }
    203   }
    204 
    205 
    206   // Computes the outgoing transitions from a state, creating new destination
    207   // states as needed.
    208   void Expand(StateId s) {
    209     Element e = elements_[s];
    210 
    211     if (e.state != kNoStateId)
    212       for (ArcIterator< Fst<A> > ait(*fst_, e.state);
    213            !ait.Done();
    214            ait.Next()) {
    215         const A &arc = ait.Value();
    216         if (!Empty(e.istring, arc.ilabel)  && !Empty(e.ostring, arc.olabel)) {
    217           const String *istring = Cdr(e.istring, arc.ilabel);
    218           const String *ostring = Cdr(e.ostring, arc.olabel);
    219           StateId d = FindState(Element(arc.nextstate, istring, ostring));
    220           AddArc(s, Arc(Car(e.istring, arc.ilabel),
    221                         Car(e.ostring, arc.olabel), arc.weight, d));
    222         } else {
    223           const String *istring = Concat(e.istring, arc.ilabel);
    224           const String *ostring = Concat(e.ostring, arc.olabel);
    225           StateId d = FindState(Element(arc.nextstate, istring, ostring));
    226           AddArc(s, Arc(0 , 0, arc.weight, d));
    227         }
    228       }
    229 
    230     Weight w = e.state == kNoStateId ? Weight::One() : fst_->Final(e.state);
    231     if ((w != Weight::Zero()) &&
    232         ((e.istring)->size() + (e.ostring)->size() > 0)) {
    233       const String *istring = Cdr(e.istring);
    234       const String *ostring = Cdr(e.ostring);
    235       StateId d = FindState(Element(kNoStateId, istring, ostring));
    236       AddArc(s, Arc(Car(e.istring), Car(e.ostring), w, d));
    237     }
    238     SetArcs(s);
    239   }
    240 
    241  private:
    242   // Equality function for Elements, assume strings have been hashed.
    243   class ElementEqual {
    244    public:
    245     bool operator()(const Element &x, const Element &y) const {
    246       return x.state == y.state &&
    247               x.istring == y.istring &&
    248               x.ostring == y.ostring;
    249     }
    250   };
    251 
    252   // Hash function for Elements to Fst states.
    253   class ElementKey {
    254    public:
    255     size_t operator()(const Element &x) const {
    256       size_t key = x.state;
    257       key = (key << 1) ^ (x.istring)->size();
    258       for (size_t i = 0; i < (x.istring)->size(); ++i)
    259         key = (key << 1) ^ (*x.istring)[i];
    260       key = (key << 1) ^ (x.ostring)->size();
    261       for (size_t i = 0; i < (x.ostring)->size(); ++i)
    262         key = (key << 1) ^ (*x.ostring)[i];
    263       return key;
    264     }
    265   };
    266 
    267   // Equality function for strings
    268   class StringEqual {
    269    public:
    270     bool operator()(const String * const &x, const String * const &y) const {
    271       if (x->size() != y->size()) return false;
    272       for (size_t i = 0; i < x->size(); ++i)
    273         if ((*x)[i] != (*y)[i]) return false;
    274       return true;
    275     }
    276   };
    277 
    278   // Hash function for set of strings
    279   class StringKey{
    280    public:
    281     size_t operator()(const String * const & x) const {
    282       size_t key = x->size();
    283       for (size_t i = 0; i < x->size(); ++i)
    284         key = (key << 1) ^ (*x)[i];
    285       return key;
    286     }
    287   };
    288 
    289 
    290   typedef std::unordered_map<Element, StateId, ElementKey, ElementEqual> ElementMap;
    291   typedef std::unordered_set<const String*, StringKey, StringEqual> StringSet;
    292 
    293   const Fst<A> *fst_;
    294   vector<Element> elements_;  // mapping Fst state to Elements
    295   ElementMap element_map_;    // mapping Elements to Fst state
    296   StringSet string_set_;
    297 
    298   DISALLOW_EVIL_CONSTRUCTORS(SynchronizeFstImpl);
    299 };
    300 
    301 
    302 // Synchronizes a transducer. This version is a delayed Fst.  The
    303 // result will be an equivalent FST that has the property that during
    304 // the traversal of a path, the delay is either zero or strictly
    305 // increasing, where the delay is the difference between the number of
    306 // non-epsilon output labels and input labels along the path.
    307 //
    308 // For the algorithm to terminate, the input transducer must have
    309 // bounded delay, i.e., the delay of every cycle must be zero.
    310 //
    311 // Complexity:
    312 // - A has bounded delay: exponential
    313 // - A does not have bounded delay: does not terminate
    314 //
    315 // References:
    316 // - Mehryar Mohri. Edit-Distance of Weighted Automata: General
    317 //   Definitions and Algorithms, International Journal of Computer
    318 //   Science, 14(6): 957-982 (2003).
    319 template <class A>
    320 class SynchronizeFst : public Fst<A> {
    321  public:
    322   friend class ArcIterator< SynchronizeFst<A> >;
    323   friend class CacheStateIterator< SynchronizeFst<A> >;
    324   friend class CacheArcIterator< SynchronizeFst<A> >;
    325 
    326   typedef A Arc;
    327   typedef typename A::Weight Weight;
    328   typedef typename A::StateId StateId;
    329   typedef CacheState<A> State;
    330 
    331   SynchronizeFst(const Fst<A> &fst)
    332       : impl_(new SynchronizeFstImpl<A>(fst, SynchronizeFstOptions())) {}
    333 
    334   SynchronizeFst(const Fst<A> &fst,  const SynchronizeFstOptions &opts)
    335       : impl_(new SynchronizeFstImpl<A>(fst, opts)) {}
    336 
    337   SynchronizeFst(const SynchronizeFst<A> &fst) : impl_(fst.impl_) {
    338     impl_->IncrRefCount();
    339   }
    340 
    341   virtual ~SynchronizeFst() { if (!impl_->DecrRefCount()) delete impl_; }
    342 
    343   virtual StateId Start() const { return impl_->Start(); }
    344 
    345   virtual Weight Final(StateId s) const { return impl_->Final(s); }
    346 
    347   virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
    348 
    349   virtual size_t NumInputEpsilons(StateId s) const {
    350     return impl_->NumInputEpsilons(s);
    351   }
    352 
    353   virtual size_t NumOutputEpsilons(StateId s) const {
    354     return impl_->NumOutputEpsilons(s);
    355   }
    356 
    357   virtual uint64 Properties(uint64 mask, bool test) const {
    358     if (test) {
    359       uint64 known, test = TestProperties(*this, mask, &known);
    360       impl_->SetProperties(test, known);
    361       return test & mask;
    362     } else {
    363       return impl_->Properties(mask);
    364     }
    365   }
    366 
    367   virtual const string& Type() const { return impl_->Type(); }
    368 
    369   virtual SynchronizeFst<A> *Copy() const {
    370     return new SynchronizeFst<A>(*this);
    371   }
    372 
    373   virtual const SymbolTable* InputSymbols() const {
    374     return impl_->InputSymbols();
    375   }
    376 
    377   virtual const SymbolTable* OutputSymbols() const {
    378     return impl_->OutputSymbols();
    379   }
    380 
    381   virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
    382 
    383   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    384     impl_->InitArcIterator(s, data);
    385   }
    386 
    387  private:
    388   SynchronizeFstImpl<A> *Impl() { return impl_; }
    389 
    390   SynchronizeFstImpl<A> *impl_;
    391 
    392   void operator=(const SynchronizeFst<A> &fst);  // Disallow
    393 };
    394 
    395 
    396 // Specialization for SynchronizeFst.
    397 template<class A>
    398 class StateIterator< SynchronizeFst<A> >
    399     : public CacheStateIterator< SynchronizeFst<A> > {
    400  public:
    401   explicit StateIterator(const SynchronizeFst<A> &fst)
    402       : CacheStateIterator< SynchronizeFst<A> >(fst) {}
    403 };
    404 
    405 
    406 // Specialization for SynchronizeFst.
    407 template <class A>
    408 class ArcIterator< SynchronizeFst<A> >
    409     : public CacheArcIterator< SynchronizeFst<A> > {
    410  public:
    411   typedef typename A::StateId StateId;
    412 
    413   ArcIterator(const SynchronizeFst<A> &fst, StateId s)
    414       : CacheArcIterator< SynchronizeFst<A> >(fst, s) {
    415     if (!fst.impl_->HasArcs(s))
    416       fst.impl_->Expand(s);
    417   }
    418 
    419  private:
    420   DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
    421 };
    422 
    423 
    424 template <class A> inline
    425 void SynchronizeFst<A>::InitStateIterator(StateIteratorData<A> *data) const
    426 {
    427   data->base = new StateIterator< SynchronizeFst<A> >(*this);
    428 }
    429 
    430 
    431 
    432 // Synchronizes a transducer. This version writes the synchronized
    433 // result to a MutableFst.  The result will be an equivalent FST that
    434 // has the property that during the traversal of a path, the delay is
    435 // either zero or strictly increasing, where the delay is the
    436 // difference between the number of non-epsilon output labels and
    437 // input labels along the path.
    438 //
    439 // For the algorithm to terminate, the input transducer must have
    440 // bounded delay, i.e., the delay of every cycle must be zero.
    441 //
    442 // Complexity:
    443 // - A has bounded delay: exponential
    444 // - A does not have bounded delay: does not terminate
    445 //
    446 // References:
    447 // - Mehryar Mohri. Edit-Distance of Weighted Automata: General
    448 //   Definitions and Algorithms, International Journal of Computer
    449 //   Science, 14(6): 957-982 (2003).
    450 template<class Arc>
    451 void Synchronize(const Fst<Arc> &ifst, MutableFst<Arc> *ofst) {
    452   SynchronizeFstOptions opts;
    453   opts.gc_limit = 0;  // Cache only the last state for fastest copy.
    454   *ofst = SynchronizeFst<Arc>(ifst, opts);
    455 }
    456 
    457 }  // namespace fst
    458 
    459 #endif // FST_LIB_SYNCHRONIZE_H__
    460