1 // factor-weight.h 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Author: allauzen (at) cs.nyu.edu (Cyril Allauzen) 16 // 17 // \file 18 // Classes to factor weights in an FST. 19 20 #ifndef FST_LIB_FACTOR_WEIGHT_H__ 21 #define FST_LIB_FACTOR_WEIGHT_H__ 22 23 #include <algorithm> 24 25 #include <ext/hash_map> 26 using __gnu_cxx::hash_map; 27 #include <ext/slist> 28 using __gnu_cxx::slist; 29 30 #include "fst/lib/cache.h" 31 #include "fst/lib/test-properties.h" 32 33 namespace fst { 34 35 struct FactorWeightOptions : CacheOptions { 36 float delta; 37 bool final_only; // only factor final weights when true 38 39 FactorWeightOptions(const CacheOptions &opts, float d, bool of) 40 : CacheOptions(opts), delta(d), final_only(of) {} 41 42 explicit FactorWeightOptions(float d, bool of = false) 43 : delta(d), final_only(of) {} 44 45 FactorWeightOptions(bool of = false) 46 : delta(kDelta), final_only(of) {} 47 }; 48 49 50 // A factor iterator takes as argument a weight w and returns a 51 // sequence of pairs of weights (xi,yi) such that the sum of the 52 // products xi times yi is equal to w. If w is fully factored, 53 // the iterator should return nothing. 54 // 55 // template <class W> 56 // class FactorIterator { 57 // public: 58 // FactorIterator(W w); 59 // bool Done() const; 60 // void Next(); 61 // pair<W, W> Value() const; 62 // void Reset(); 63 // } 64 65 66 // Factor trivially. 67 template <class W> 68 class IdentityFactor { 69 public: 70 IdentityFactor(const W &w) {} 71 bool Done() const { return true; } 72 void Next() {} 73 pair<W, W> Value() const { return make_pair(W::One(), W::One()); } // unused 74 void Reset() {} 75 }; 76 77 78 // Factor a StringWeight w as 'ab' where 'a' is a label. 79 template <typename L, StringType S = STRING_LEFT> 80 class StringFactor { 81 public: 82 StringFactor(const StringWeight<L, S> &w) 83 : weight_(w), done_(w.Size() <= 1) {} 84 85 bool Done() const { return done_; } 86 87 void Next() { done_ = true; } 88 89 pair< StringWeight<L, S>, StringWeight<L, S> > Value() const { 90 StringWeightIterator<L, S> iter(weight_); 91 StringWeight<L, S> w1(iter.Value()); 92 StringWeight<L, S> w2; 93 for (iter.Next(); !iter.Done(); iter.Next()) 94 w2.PushBack(iter.Value()); 95 return make_pair(w1, w2); 96 } 97 98 void Reset() { done_ = weight_.Size() <= 1; } 99 100 private: 101 StringWeight<L, S> weight_; 102 bool done_; 103 }; 104 105 106 // Factor a GallicWeight using StringFactor. 107 template <class L, class W, StringType S = STRING_LEFT> 108 class GallicFactor { 109 public: 110 GallicFactor(const GallicWeight<L, W, S> &w) 111 : weight_(w), done_(w.Value1().Size() <= 1) {} 112 113 bool Done() const { return done_; } 114 115 void Next() { done_ = true; } 116 117 pair< GallicWeight<L, W, S>, GallicWeight<L, W, S> > Value() const { 118 StringFactor<L, S> iter(weight_.Value1()); 119 GallicWeight<L, W, S> w1(iter.Value().first, weight_.Value2()); 120 GallicWeight<L, W, S> w2(iter.Value().second, W::One()); 121 return make_pair(w1, w2); 122 } 123 124 void Reset() { done_ = weight_.Value1().Size() <= 1; } 125 126 private: 127 GallicWeight<L, W, S> weight_; 128 bool done_; 129 }; 130 131 132 // Implementation class for FactorWeight 133 template <class A, class F> 134 class FactorWeightFstImpl 135 : public CacheImpl<A> { 136 public: 137 using FstImpl<A>::SetType; 138 using FstImpl<A>::SetProperties; 139 using FstImpl<A>::Properties; 140 using FstImpl<A>::SetInputSymbols; 141 using FstImpl<A>::SetOutputSymbols; 142 143 using CacheBaseImpl< CacheState<A> >::HasStart; 144 using CacheBaseImpl< CacheState<A> >::HasFinal; 145 using CacheBaseImpl< CacheState<A> >::HasArcs; 146 147 typedef A Arc; 148 typedef typename A::Label Label; 149 typedef typename A::Weight Weight; 150 typedef typename A::StateId StateId; 151 typedef F FactorIterator; 152 153 struct Element { 154 Element() {} 155 156 Element(StateId s, Weight w) : state(s), weight(w) {} 157 158 StateId state; // Input state Id 159 Weight weight; // Residual weight 160 }; 161 162 FactorWeightFstImpl(const Fst<A> &fst, const FactorWeightOptions &opts) 163 : CacheImpl<A>(opts), fst_(fst.Copy()), delta_(opts.delta), 164 final_only_(opts.final_only) { 165 SetType("factor-weight"); 166 uint64 props = fst.Properties(kFstProperties, false); 167 SetProperties(FactorWeightProperties(props), kCopyProperties); 168 169 SetInputSymbols(fst.InputSymbols()); 170 SetOutputSymbols(fst.OutputSymbols()); 171 } 172 173 ~FactorWeightFstImpl() { 174 delete fst_; 175 } 176 177 StateId Start() { 178 if (!HasStart()) { 179 StateId s = fst_->Start(); 180 if (s == kNoStateId) 181 return kNoStateId; 182 StateId start = FindState(Element(fst_->Start(), Weight::One())); 183 SetStart(start); 184 } 185 return CacheImpl<A>::Start(); 186 } 187 188 Weight Final(StateId s) { 189 if (!HasFinal(s)) { 190 const Element &e = elements_[s]; 191 // TODO: fix so cast is unnecessary 192 Weight w = e.state == kNoStateId 193 ? e.weight 194 : (Weight) Times(e.weight, fst_->Final(e.state)); 195 FactorIterator f(w); 196 if (w != Weight::Zero() && f.Done()) 197 SetFinal(s, w); 198 else 199 SetFinal(s, Weight::Zero()); 200 } 201 return CacheImpl<A>::Final(s); 202 } 203 204 size_t NumArcs(StateId s) { 205 if (!HasArcs(s)) 206 Expand(s); 207 return CacheImpl<A>::NumArcs(s); 208 } 209 210 size_t NumInputEpsilons(StateId s) { 211 if (!HasArcs(s)) 212 Expand(s); 213 return CacheImpl<A>::NumInputEpsilons(s); 214 } 215 216 size_t NumOutputEpsilons(StateId s) { 217 if (!HasArcs(s)) 218 Expand(s); 219 return CacheImpl<A>::NumOutputEpsilons(s); 220 } 221 222 void InitArcIterator(StateId s, ArcIteratorData<A> *data) { 223 if (!HasArcs(s)) 224 Expand(s); 225 CacheImpl<A>::InitArcIterator(s, data); 226 } 227 228 229 // Find state corresponding to an element. Create new state 230 // if element not found. 231 StateId FindState(const Element &e) { 232 if (final_only_ && e.weight == Weight::One()) { 233 while (unfactored_.size() <= (unsigned int)e.state) 234 unfactored_.push_back(kNoStateId); 235 if (unfactored_[e.state] == kNoStateId) { 236 unfactored_[e.state] = elements_.size(); 237 elements_.push_back(e); 238 } 239 return unfactored_[e.state]; 240 } else { 241 typename ElementMap::iterator eit = element_map_.find(e); 242 if (eit != element_map_.end()) { 243 return (*eit).second; 244 } else { 245 StateId s = elements_.size(); 246 elements_.push_back(e); 247 element_map_.insert(pair<const Element, StateId>(e, s)); 248 return s; 249 } 250 } 251 } 252 253 // Computes the outgoing transitions from a state, creating new destination 254 // states as needed. 255 void Expand(StateId s) { 256 Element e = elements_[s]; 257 if (e.state != kNoStateId) { 258 for (ArcIterator< Fst<A> > ait(*fst_, e.state); 259 !ait.Done(); 260 ait.Next()) { 261 const A &arc = ait.Value(); 262 Weight w = Times(e.weight, arc.weight); 263 FactorIterator fit(w); 264 if (final_only_ || fit.Done()) { 265 StateId d = FindState(Element(arc.nextstate, Weight::One())); 266 AddArc(s, Arc(arc.ilabel, arc.olabel, w, d)); 267 } else { 268 for (; !fit.Done(); fit.Next()) { 269 const pair<Weight, Weight> &p = fit.Value(); 270 StateId d = FindState(Element(arc.nextstate, 271 p.second.Quantize(delta_))); 272 AddArc(s, Arc(arc.ilabel, arc.olabel, p.first, d)); 273 } 274 } 275 } 276 } 277 if ((e.state == kNoStateId) || 278 (fst_->Final(e.state) != Weight::Zero())) { 279 Weight w = e.state == kNoStateId 280 ? e.weight 281 : Times(e.weight, fst_->Final(e.state)); 282 for (FactorIterator fit(w); 283 !fit.Done(); 284 fit.Next()) { 285 const pair<Weight, Weight> &p = fit.Value(); 286 StateId d = FindState(Element(kNoStateId, 287 p.second.Quantize(delta_))); 288 AddArc(s, Arc(0, 0, p.first, d)); 289 } 290 } 291 SetArcs(s); 292 } 293 294 private: 295 // Equality function for Elements, assume weights have been quantized. 296 class ElementEqual { 297 public: 298 bool operator()(const Element &x, const Element &y) const { 299 return x.state == y.state && x.weight == y.weight; 300 } 301 }; 302 303 // Hash function for Elements to Fst states. 304 class ElementKey { 305 public: 306 size_t operator()(const Element &x) const { 307 return static_cast<size_t>(x.state * kPrime + x.weight.Hash()); 308 } 309 private: 310 static const int kPrime = 7853; 311 }; 312 313 typedef hash_map<Element, StateId, ElementKey, ElementEqual> ElementMap; 314 315 const Fst<A> *fst_; 316 float delta_; 317 bool final_only_; 318 vector<Element> elements_; // mapping Fst state to Elements 319 ElementMap element_map_; // mapping Elements to Fst state 320 // mapping between old/new 'StateId' for states that do not need to 321 // be factored when 'final_only_' is true 322 vector<StateId> unfactored_; 323 324 DISALLOW_EVIL_CONSTRUCTORS(FactorWeightFstImpl); 325 }; 326 327 328 // FactorWeightFst takes as template parameter a FactorIterator as 329 // defined above. The result of weight factoring is a transducer 330 // equivalent to the input whose path weights have been factored 331 // according to the FactorIterator. States and transitions will be 332 // added as necessary. The algorithm is a generalization to arbitrary 333 // weights of the second step of the input epsilon-normalization 334 // algorithm due to Mohri, "Generic epsilon-removal and input 335 // epsilon-normalization algorithms for weighted transducers", 336 // International Journal of Computer Science 13(1): 129-143 (2002). 337 template <class A, class F> 338 class FactorWeightFst : public Fst<A> { 339 public: 340 friend class ArcIterator< FactorWeightFst<A, F> >; 341 friend class CacheStateIterator< FactorWeightFst<A, F> >; 342 friend class CacheArcIterator< FactorWeightFst<A, F> >; 343 344 typedef A Arc; 345 typedef typename A::Weight Weight; 346 typedef typename A::StateId StateId; 347 typedef CacheState<A> State; 348 349 FactorWeightFst(const Fst<A> &fst) 350 : impl_(new FactorWeightFstImpl<A, F>(fst, FactorWeightOptions())) {} 351 352 FactorWeightFst(const Fst<A> &fst, const FactorWeightOptions &opts) 353 : impl_(new FactorWeightFstImpl<A, F>(fst, opts)) {} 354 FactorWeightFst(const FactorWeightFst<A, F> &fst) : Fst<A>(fst), impl_(fst.impl_) { 355 impl_->IncrRefCount(); 356 } 357 358 virtual ~FactorWeightFst() { if (!impl_->DecrRefCount()) delete impl_; } 359 360 virtual StateId Start() const { return impl_->Start(); } 361 362 virtual Weight Final(StateId s) const { return impl_->Final(s); } 363 364 virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); } 365 366 virtual size_t NumInputEpsilons(StateId s) const { 367 return impl_->NumInputEpsilons(s); 368 } 369 370 virtual size_t NumOutputEpsilons(StateId s) const { 371 return impl_->NumOutputEpsilons(s); 372 } 373 374 virtual uint64 Properties(uint64 mask, bool test) const { 375 if (test) { 376 uint64 known, test = TestProperties(*this, mask, &known); 377 impl_->SetProperties(test, known); 378 return test & mask; 379 } else { 380 return impl_->Properties(mask); 381 } 382 } 383 384 virtual const string& Type() const { return impl_->Type(); } 385 386 virtual FactorWeightFst<A, F> *Copy() const { 387 return new FactorWeightFst<A, F>(*this); 388 } 389 390 virtual const SymbolTable* InputSymbols() const { 391 return impl_->InputSymbols(); 392 } 393 394 virtual const SymbolTable* OutputSymbols() const { 395 return impl_->OutputSymbols(); 396 } 397 398 virtual inline void InitStateIterator(StateIteratorData<A> *data) const; 399 400 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { 401 impl_->InitArcIterator(s, data); 402 } 403 404 private: 405 FactorWeightFstImpl<A, F> *Impl() { return impl_; } 406 407 FactorWeightFstImpl<A, F> *impl_; 408 409 void operator=(const FactorWeightFst<A, F> &fst); // Disallow 410 }; 411 412 413 // Specialization for FactorWeightFst. 414 template<class A, class F> 415 class StateIterator< FactorWeightFst<A, F> > 416 : public CacheStateIterator< FactorWeightFst<A, F> > { 417 public: 418 explicit StateIterator(const FactorWeightFst<A, F> &fst) 419 : CacheStateIterator< FactorWeightFst<A, F> >(fst) {} 420 }; 421 422 423 // Specialization for FactorWeightFst. 424 template <class A, class F> 425 class ArcIterator< FactorWeightFst<A, F> > 426 : public CacheArcIterator< FactorWeightFst<A, F> > { 427 public: 428 typedef typename A::StateId StateId; 429 430 ArcIterator(const FactorWeightFst<A, F> &fst, StateId s) 431 : CacheArcIterator< FactorWeightFst<A, F> >(fst, s) { 432 if (!fst.impl_->HasArcs(s)) 433 fst.impl_->Expand(s); 434 } 435 436 private: 437 DISALLOW_EVIL_CONSTRUCTORS(ArcIterator); 438 }; 439 440 template <class A, class F> inline 441 void FactorWeightFst<A, F>::InitStateIterator(StateIteratorData<A> *data) const 442 { 443 data->base = new StateIterator< FactorWeightFst<A, F> >(*this); 444 } 445 446 447 } // namespace fst 448 449 #endif // FST_LIB_FACTOR_WEIGHT_H__ 450