Home | History | Annotate | Download | only in lib
      1 // arcsort.h
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 //
     16 // \file
     17 // Functions and classes to sort arcs in an FST.
     18 
     19 #ifndef FST_LIB_ARCSORT_H__
     20 #define FST_LIB_ARCSORT_H__
     21 
     22 #include <algorithm>
     23 
     24 #include "fst/lib/cache.h"
     25 #include "fst/lib/test-properties.h"
     26 
     27 namespace fst {
     28 
     29 // Sorts the arcs in an FST according to function object 'comp' of
     30 // type Compare. This version modifies its input.  Comparison function
     31 // objects IlabelCompare and OlabelCompare are provived by the
     32 // library. In general, Compare must meet the requirements for an STL
     33 // sort comparision function object. It must also have a member
     34 // Properties(uint64) that specifies the known properties of the
     35 // sorted FST; it takes as argument the input FST's known properties
     36 // before the sort.
     37 //
     38 // Complexity:
     39 // - Time: O(V + D log D)
     40 // - Space: O(D)
     41 // where V = # of states and D = maximum out-degree.
     42 template<class Arc, class Compare>
     43 void ArcSort(MutableFst<Arc> *fst, Compare comp) {
     44   typedef typename Arc::StateId StateId;
     45 
     46   uint64 props = fst->Properties(kFstProperties, false);
     47 
     48   vector<Arc> arcs;
     49   for (StateIterator< MutableFst<Arc> > siter(*fst);
     50        !siter.Done();
     51        siter.Next()) {
     52     StateId s = siter.Value();
     53     arcs.clear();
     54     for (ArcIterator< MutableFst<Arc> > aiter(*fst, s);
     55          !aiter.Done();
     56          aiter.Next())
     57       arcs.push_back(aiter.Value());
     58     sort(arcs.begin(), arcs.end(), comp);
     59     fst->DeleteArcs(s);
     60     for (size_t a = 0; a < arcs.size(); ++a)
     61       fst->AddArc(s, arcs[a]);
     62   }
     63 
     64   fst->SetProperties(comp.Properties(props), kFstProperties);
     65 }
     66 
     67 typedef CacheOptions ArcSortFstOptions;
     68 
     69 // Implementation of delayed ArcSortFst.
     70 template<class A, class C>
     71 class ArcSortFstImpl : public CacheImpl<A> {
     72  public:
     73   using FstImpl<A>::SetType;
     74   using FstImpl<A>::SetProperties;
     75   using FstImpl<A>::Properties;
     76   using FstImpl<A>::SetInputSymbols;
     77   using FstImpl<A>::SetOutputSymbols;
     78   using FstImpl<A>::InputSymbols;
     79   using FstImpl<A>::OutputSymbols;
     80 
     81   using VectorFstBaseImpl<typename CacheImpl<A>::State>::NumStates;
     82 
     83   using CacheImpl<A>::HasArcs;
     84   using CacheImpl<A>::HasFinal;
     85   using CacheImpl<A>::HasStart;
     86 
     87   typedef typename A::Weight Weight;
     88   typedef typename A::StateId StateId;
     89 
     90   ArcSortFstImpl(const Fst<A> &fst, const C &comp,
     91                  const ArcSortFstOptions &opts)
     92       : CacheImpl<A>(opts), fst_(fst.Copy()), comp_(comp) {
     93     SetType("arcsort");
     94     uint64 props = fst_->Properties(kCopyProperties, false);
     95     SetProperties(comp_.Properties(props));
     96     SetInputSymbols(fst.InputSymbols());
     97     SetOutputSymbols(fst.OutputSymbols());
     98   }
     99 
    100   ArcSortFstImpl(const ArcSortFstImpl& impl)
    101       : fst_(impl.fst_->Copy()), comp_(impl.comp_) {
    102     SetType("arcsort");
    103     SetProperties(impl.Properties(), kCopyProperties);
    104     SetInputSymbols(impl.InputSymbols());
    105     SetOutputSymbols(impl.OutputSymbols());
    106   }
    107 
    108   ~ArcSortFstImpl() { delete fst_; }
    109 
    110   StateId Start() {
    111     if (!HasStart())
    112       SetStart(fst_->Start());
    113     return CacheImpl<A>::Start();
    114   }
    115 
    116   Weight Final(StateId s) {
    117     if (!HasFinal(s))
    118       SetFinal(s, fst_->Final(s));
    119     return CacheImpl<A>::Final(s);
    120   }
    121 
    122   size_t NumArcs(StateId s) {
    123     if (!HasArcs(s))
    124       Expand(s);
    125     return CacheImpl<A>::NumArcs(s);
    126   }
    127 
    128   size_t NumInputEpsilons(StateId s) {
    129     if (!HasArcs(s))
    130       Expand(s);
    131     return CacheImpl<A>::NumInputEpsilons(s);
    132   }
    133 
    134   size_t NumOutputEpsilons(StateId s) {
    135     if (!HasArcs(s))
    136       Expand(s);
    137     return CacheImpl<A>::NumOutputEpsilons(s);
    138   }
    139 
    140   void InitStateIterator(StateIteratorData<A> *data) const {
    141     fst_->InitStateIterator(data);
    142   }
    143 
    144   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
    145     if (!HasArcs(s))
    146       Expand(s);
    147     CacheImpl<A>::InitArcIterator(s, data);
    148   }
    149 
    150   void Expand(StateId s) {
    151     for (ArcIterator< Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next())
    152       AddArc(s, aiter.Value());
    153     SetArcs(s);
    154 
    155     if (s < NumStates()) {  // ensure state exists
    156       vector<A> &carcs = GetState(s)->arcs;
    157       sort(carcs.begin(), carcs.end(), comp_);
    158     }
    159   }
    160 
    161  private:
    162   const Fst<A> *fst_;
    163   C comp_;
    164 
    165   void operator=(const ArcSortFstImpl<A, C> &impl);  // Disallow
    166 };
    167 
    168 
    169 // Sorts the arcs in an FST according to function object 'comp' of
    170 // type Compare. This version is a delayed Fst.  Comparsion function
    171 // objects IlabelCompare and OlabelCompare are provided by the
    172 // library. In general, Compare must meet the requirements for an STL
    173 // comparision function object (e.g. as used for STL sort). It must
    174 // also have a member Properties(uint64) that specifies the known
    175 // properties of the sorted FST; it takes as argument the input FST's
    176 // known properties.
    177 //
    178 // Complexity:
    179 // - Time: O(v + d log d)
    180 // - Space: O(v + d)
    181 // where v = # of states visited, d = maximum out-degree of states
    182 // visited. Constant time and space to visit an input state is assumed
    183 // and exclusive of caching.
    184 template <class A, class C>
    185 class ArcSortFst : public Fst<A> {
    186  public:
    187   friend class CacheArcIterator< ArcSortFst<A, C> >;
    188   friend class ArcIterator< ArcSortFst<A, C> >;
    189 
    190   typedef A Arc;
    191   typedef C Compare;
    192   typedef typename A::Weight Weight;
    193   typedef typename A::StateId StateId;
    194   typedef CacheState<A> State;
    195 
    196   ArcSortFst(const Fst<A> &fst, const C &comp)
    197       : impl_(new ArcSortFstImpl<A, C>(fst, comp, ArcSortFstOptions())) {}
    198 
    199   ArcSortFst(const Fst<A> &fst, const C &comp, const ArcSortFstOptions &opts)
    200       : impl_(new ArcSortFstImpl<A, C>(fst, comp, opts)) {}
    201 
    202   ArcSortFst(const ArcSortFst<A, C> &fst) :
    203       impl_(new ArcSortFstImpl<A, C>(*(fst.impl_))) {}
    204 
    205   virtual ~ArcSortFst() { if (!impl_->DecrRefCount()) delete impl_; }
    206 
    207   virtual StateId Start() const { return impl_->Start(); }
    208 
    209   virtual Weight Final(StateId s) const { return impl_->Final(s); }
    210 
    211   virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
    212 
    213   virtual size_t NumInputEpsilons(StateId s) const {
    214     return impl_->NumInputEpsilons(s);
    215   }
    216 
    217   virtual size_t NumOutputEpsilons(StateId s) const {
    218     return impl_->NumOutputEpsilons(s);
    219   }
    220 
    221   virtual uint64 Properties(uint64 mask, bool test) const {
    222     if (test) {
    223       uint64 known, test = TestProperties(*this, mask, &known);
    224       impl_->SetProperties(test, known);
    225       return test & mask;
    226     } else {
    227       return impl_->Properties(mask);
    228     }
    229   }
    230 
    231   virtual const string& Type() const { return impl_->Type(); }
    232 
    233   virtual ArcSortFst<A, C> *Copy() const {
    234     return new ArcSortFst<A, C>(*this);
    235   }
    236 
    237   virtual const SymbolTable* InputSymbols() const {
    238     return impl_->InputSymbols();
    239   }
    240 
    241   virtual const SymbolTable* OutputSymbols() const {
    242     return impl_->OutputSymbols();
    243   }
    244 
    245   virtual void InitStateIterator(StateIteratorData<A> *data) const {
    246     impl_->InitStateIterator(data);
    247   }
    248 
    249   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    250     impl_->InitArcIterator(s, data);
    251   }
    252 
    253  private:
    254   ArcSortFstImpl<A, C> *impl_;
    255 
    256   void operator=(const ArcSortFst<A, C> &fst);  // Disallow
    257 };
    258 
    259 
    260 // Specialization for ArcSortFst.
    261 template <class A, class C>
    262 class ArcIterator< ArcSortFst<A, C> >
    263     : public CacheArcIterator< ArcSortFst<A, C> > {
    264  public:
    265   typedef typename A::StateId StateId;
    266 
    267   ArcIterator(const ArcSortFst<A, C> &fst, StateId s)
    268       : CacheArcIterator< ArcSortFst<A, C> >(fst, s) {
    269     if (!fst.impl_->HasArcs(s))
    270       fst.impl_->Expand(s);
    271   }
    272 
    273  private:
    274   DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
    275 };
    276 
    277 
    278 // Compare class for comparing input labels of arcs.
    279 template<class A> class ILabelCompare {
    280  public:
    281   bool operator() (A arc1, A arc2) const {
    282     return arc1.ilabel < arc2.ilabel;
    283   }
    284 
    285   uint64 Properties(uint64 props) const {
    286     return props & kArcSortProperties | kILabelSorted;
    287   }
    288 };
    289 
    290 
    291 // Compare class for comparing output labels of arcs.
    292 template<class A> class OLabelCompare {
    293  public:
    294   bool operator() (const A &arc1, const A &arc2) const {
    295     return arc1.olabel < arc2.olabel;
    296   }
    297 
    298   uint64 Properties(uint64 props) const {
    299     return props & kArcSortProperties | kOLabelSorted;
    300   }
    301 };
    302 
    303 
    304 // Useful aliases when using StdArc.
    305 template<class C> class StdArcSortFst : public ArcSortFst<StdArc, C> {
    306  public:
    307   typedef StdArc Arc;
    308   typedef C Compare;
    309 };
    310 
    311 typedef ILabelCompare<StdArc> StdILabelCompare;
    312 
    313 typedef OLabelCompare<StdArc> StdOLabelCompare;
    314 
    315 }  // namespace fst
    316 
    317 #endif  // FST_LIB_ARCSORT_H__
    318