1 // arcsort.h 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // 16 // \file 17 // Functions and classes to sort arcs in an FST. 18 19 #ifndef FST_LIB_ARCSORT_H__ 20 #define FST_LIB_ARCSORT_H__ 21 22 #include <algorithm> 23 24 #include "fst/lib/cache.h" 25 #include "fst/lib/test-properties.h" 26 27 namespace fst { 28 29 // Sorts the arcs in an FST according to function object 'comp' of 30 // type Compare. This version modifies its input. Comparison function 31 // objects ILabelCompare and OLabelCompare are provided by the 32 // library. In general, Compare must meet the requirements for an STL 33 // sort comparison function object. It must also have a member 34 // Properties(uint64) that specifies the known properties of the 35 // sorted FST; it takes as argument the input FST's known properties 36 // before the sort. 37 // 38 // Complexity: 39 // - Time: O(V D log D) 40 // - Space: O(D) 41 // where V = # of states and D = maximum out-degree.
42 template<class Arc, class Compare> 43 void ArcSort(MutableFst<Arc> *fst, Compare comp) { 44 typedef typename Arc::StateId StateId; 45 46 uint64 props = fst->Properties(kFstProperties, false); 47 48 vector<Arc> arcs; 49 for (StateIterator< MutableFst<Arc> > siter(*fst); 50 !siter.Done(); 51 siter.Next()) { 52 StateId s = siter.Value(); 53 arcs.clear(); 54 for (ArcIterator< MutableFst<Arc> > aiter(*fst, s); 55 !aiter.Done(); 56 aiter.Next()) 57 arcs.push_back(aiter.Value()); 58 sort(arcs.begin(), arcs.end(), comp); 59 fst->DeleteArcs(s); 60 for (size_t a = 0; a < arcs.size(); ++a) 61 fst->AddArc(s, arcs[a]); 62 } 63 64 fst->SetProperties(comp.Properties(props), kFstProperties); 65 } 66 67 typedef CacheOptions ArcSortFstOptions; 68 69 // Implementation of delayed ArcSortFst. 70 template<class A, class C> 71 class ArcSortFstImpl : public CacheImpl<A> { 72 public: 73 using FstImpl<A>::SetType; 74 using FstImpl<A>::SetProperties; 75 using FstImpl<A>::Properties; 76 using FstImpl<A>::SetInputSymbols; 77 using FstImpl<A>::SetOutputSymbols; 78 using FstImpl<A>::InputSymbols; 79 using FstImpl<A>::OutputSymbols; 80 81 using VectorFstBaseImpl<typename CacheImpl<A>::State>::NumStates; 82 83 using CacheImpl<A>::HasArcs; 84 using CacheImpl<A>::HasFinal; 85 using CacheImpl<A>::HasStart; 86 87 typedef typename A::Weight Weight; 88 typedef typename A::StateId StateId; 89 90 ArcSortFstImpl(const Fst<A> &fst, const C &comp, 91 const ArcSortFstOptions &opts) 92 : CacheImpl<A>(opts), fst_(fst.Copy()), comp_(comp) { 93 SetType("arcsort"); 94 uint64 props = fst_->Properties(kCopyProperties, false); 95 SetProperties(comp_.Properties(props)); 96 SetInputSymbols(fst.InputSymbols()); 97 SetOutputSymbols(fst.OutputSymbols()); 98 } 99 100 ArcSortFstImpl(const ArcSortFstImpl& impl) 101 : fst_(impl.fst_->Copy()), comp_(impl.comp_) { 102 SetType("arcsort"); 103 SetProperties(impl.Properties(), kCopyProperties); 104 SetInputSymbols(impl.InputSymbols()); 105 
SetOutputSymbols(impl.OutputSymbols()); 106 } 107 108 ~ArcSortFstImpl() { delete fst_; } 109 110 StateId Start() { 111 if (!HasStart()) 112 SetStart(fst_->Start()); 113 return CacheImpl<A>::Start(); 114 } 115 116 Weight Final(StateId s) { 117 if (!HasFinal(s)) 118 SetFinal(s, fst_->Final(s)); 119 return CacheImpl<A>::Final(s); 120 } 121 122 size_t NumArcs(StateId s) { 123 if (!HasArcs(s)) 124 Expand(s); 125 return CacheImpl<A>::NumArcs(s); 126 } 127 128 size_t NumInputEpsilons(StateId s) { 129 if (!HasArcs(s)) 130 Expand(s); 131 return CacheImpl<A>::NumInputEpsilons(s); 132 } 133 134 size_t NumOutputEpsilons(StateId s) { 135 if (!HasArcs(s)) 136 Expand(s); 137 return CacheImpl<A>::NumOutputEpsilons(s); 138 } 139 140 void InitStateIterator(StateIteratorData<A> *data) const { 141 fst_->InitStateIterator(data); 142 } 143 144 void InitArcIterator(StateId s, ArcIteratorData<A> *data) { 145 if (!HasArcs(s)) 146 Expand(s); 147 CacheImpl<A>::InitArcIterator(s, data); 148 } 149 150 void Expand(StateId s) { 151 for (ArcIterator< Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) 152 AddArc(s, aiter.Value()); 153 SetArcs(s); 154 155 if (s < NumStates()) { // ensure state exists 156 vector<A> &carcs = GetState(s)->arcs; 157 sort(carcs.begin(), carcs.end(), comp_); 158 } 159 } 160 161 private: 162 const Fst<A> *fst_; 163 C comp_; 164 165 void operator=(const ArcSortFstImpl<A, C> &impl); // Disallow 166 }; 167 168 169 // Sorts the arcs in an FST according to function object 'comp' of 170 // type Compare. This version is a delayed Fst. Comparsion function 171 // objects IlabelCompare and OlabelCompare are provided by the 172 // library. In general, Compare must meet the requirements for an STL 173 // comparision function object (e.g. as used for STL sort). It must 174 // also have a member Properties(uint64) that specifies the known 175 // properties of the sorted FST; it takes as argument the input FST's 176 // known properties. 
177 // 178 // Complexity: 179 // - Time: O(v + d log d) 180 // - Space: O(v + d) 181 // where v = # of states visited, d = maximum out-degree of states 182 // visited. Constant time and space to visit an input state is assumed 183 // and exclusive of caching. 184 template <class A, class C> 185 class ArcSortFst : public Fst<A> { 186 public: 187 friend class CacheArcIterator< ArcSortFst<A, C> >; 188 friend class ArcIterator< ArcSortFst<A, C> >; 189 190 typedef A Arc; 191 typedef C Compare; 192 typedef typename A::Weight Weight; 193 typedef typename A::StateId StateId; 194 typedef CacheState<A> State; 195 196 ArcSortFst(const Fst<A> &fst, const C &comp) 197 : impl_(new ArcSortFstImpl<A, C>(fst, comp, ArcSortFstOptions())) {} 198 199 ArcSortFst(const Fst<A> &fst, const C &comp, const ArcSortFstOptions &opts) 200 : impl_(new ArcSortFstImpl<A, C>(fst, comp, opts)) {} 201 202 ArcSortFst(const ArcSortFst<A, C> &fst) : 203 impl_(new ArcSortFstImpl<A, C>(*(fst.impl_))) {} 204 205 virtual ~ArcSortFst() { if (!impl_->DecrRefCount()) delete impl_; } 206 207 virtual StateId Start() const { return impl_->Start(); } 208 209 virtual Weight Final(StateId s) const { return impl_->Final(s); } 210 211 virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); } 212 213 virtual size_t NumInputEpsilons(StateId s) const { 214 return impl_->NumInputEpsilons(s); 215 } 216 217 virtual size_t NumOutputEpsilons(StateId s) const { 218 return impl_->NumOutputEpsilons(s); 219 } 220 221 virtual uint64 Properties(uint64 mask, bool test) const { 222 if (test) { 223 uint64 known, test = TestProperties(*this, mask, &known); 224 impl_->SetProperties(test, known); 225 return test & mask; 226 } else { 227 return impl_->Properties(mask); 228 } 229 } 230 231 virtual const string& Type() const { return impl_->Type(); } 232 233 virtual ArcSortFst<A, C> *Copy() const { 234 return new ArcSortFst<A, C>(*this); 235 } 236 237 virtual const SymbolTable* InputSymbols() const { 238 return 
impl_->InputSymbols(); 239 } 240 241 virtual const SymbolTable* OutputSymbols() const { 242 return impl_->OutputSymbols(); 243 } 244 245 virtual void InitStateIterator(StateIteratorData<A> *data) const { 246 impl_->InitStateIterator(data); 247 } 248 249 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { 250 impl_->InitArcIterator(s, data); 251 } 252 253 private: 254 ArcSortFstImpl<A, C> *impl_; 255 256 void operator=(const ArcSortFst<A, C> &fst); // Disallow 257 }; 258 259 260 // Specialization for ArcSortFst. 261 template <class A, class C> 262 class ArcIterator< ArcSortFst<A, C> > 263 : public CacheArcIterator< ArcSortFst<A, C> > { 264 public: 265 typedef typename A::StateId StateId; 266 267 ArcIterator(const ArcSortFst<A, C> &fst, StateId s) 268 : CacheArcIterator< ArcSortFst<A, C> >(fst, s) { 269 if (!fst.impl_->HasArcs(s)) 270 fst.impl_->Expand(s); 271 } 272 273 private: 274 DISALLOW_EVIL_CONSTRUCTORS(ArcIterator); 275 }; 276 277 278 // Compare class for comparing input labels of arcs. 279 template<class A> class ILabelCompare { 280 public: 281 bool operator() (A arc1, A arc2) const { 282 return arc1.ilabel < arc2.ilabel; 283 } 284 285 uint64 Properties(uint64 props) const { 286 return props & kArcSortProperties | kILabelSorted; 287 } 288 }; 289 290 291 // Compare class for comparing output labels of arcs. 292 template<class A> class OLabelCompare { 293 public: 294 bool operator() (const A &arc1, const A &arc2) const { 295 return arc1.olabel < arc2.olabel; 296 } 297 298 uint64 Properties(uint64 props) const { 299 return props & kArcSortProperties | kOLabelSorted; 300 } 301 }; 302 303 304 // Useful aliases when using StdArc. 
305 template<class C> class StdArcSortFst : public ArcSortFst<StdArc, C> { 306 public: 307 typedef StdArc Arc; 308 typedef C Compare; 309 }; 310 311 typedef ILabelCompare<StdArc> StdILabelCompare; 312 313 typedef OLabelCompare<StdArc> StdOLabelCompare; 314 315 } // namespace fst 316 317 #endif // FST_LIB_ARCSORT_H__ 318