1 // factor-weight.h 2 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Copyright 2005-2010 Google, Inc. 16 // Author: allauzen (at) google.com (Cyril Allauzen) 17 // 18 // \file 19 // Classes to factor weights in an FST. 20 21 #ifndef FST_LIB_FACTOR_WEIGHT_H__ 22 #define FST_LIB_FACTOR_WEIGHT_H__ 23 24 #include <algorithm> 25 #include <tr1/unordered_map> 26 using std::tr1::unordered_map; 27 using std::tr1::unordered_multimap; 28 #include <string> 29 #include <utility> 30 using std::pair; using std::make_pair; 31 #include <vector> 32 using std::vector; 33 34 #include <fst/cache.h> 35 #include <fst/test-properties.h> 36 37 38 namespace fst { 39 40 const uint32 kFactorFinalWeights = 0x00000001; 41 const uint32 kFactorArcWeights = 0x00000002; 42 43 template <class Arc> 44 struct FactorWeightOptions : CacheOptions { 45 typedef typename Arc::Label Label; 46 float delta; 47 uint32 mode; // factor arc weights and/or final weights 48 Label final_ilabel; // input label of arc created when factoring final w's 49 Label final_olabel; // output label of arc created when factoring final w's 50 51 FactorWeightOptions(const CacheOptions &opts, float d, 52 uint32 m = kFactorArcWeights | kFactorFinalWeights, 53 Label il = 0, Label ol = 0) 54 : CacheOptions(opts), delta(d), mode(m), final_ilabel(il), 55 final_olabel(ol) {} 56 57 explicit FactorWeightOptions( 58 float d, uint32 m = kFactorArcWeights | kFactorFinalWeights, 59 Label il = 0, Label ol = 0) 60 : delta(d), mode(m), final_ilabel(il), final_olabel(ol) {} 61 62 FactorWeightOptions(uint32 m = kFactorArcWeights | kFactorFinalWeights, 63 Label il = 0, Label ol = 0) 64 : delta(kDelta), mode(m), final_ilabel(il), final_olabel(ol) {} 65 }; 66 67 68 // A factor iterator takes as argument a weight w and returns a 69 // sequence of pairs of weights (xi,yi) such that the sum of the 70 // products xi times yi is equal to w. If w is fully factored, 71 // the iterator should return nothing. 72 // 73 // template <class W> 74 // class FactorIterator { 75 // public: 76 // FactorIterator(W w); 77 // bool Done() const; 78 // void Next(); 79 // pair<W, W> Value() const; 80 // void Reset(); 81 // } 82 83 84 // Factor trivially. 85 template <class W> 86 class IdentityFactor { 87 public: 88 IdentityFactor(const W &w) {} 89 bool Done() const { return true; } 90 void Next() {} 91 pair<W, W> Value() const { return make_pair(W::One(), W::One()); } // unused 92 void Reset() {} 93 }; 94 95 96 // Factor a StringWeight w as 'ab' where 'a' is a label. 97 template <typename L, StringType S = STRING_LEFT> 98 class StringFactor { 99 public: 100 StringFactor(const StringWeight<L, S> &w) 101 : weight_(w), done_(w.Size() <= 1) {} 102 103 bool Done() const { return done_; } 104 105 void Next() { done_ = true; } 106 107 pair< StringWeight<L, S>, StringWeight<L, S> > Value() const { 108 StringWeightIterator<L, S> iter(weight_); 109 StringWeight<L, S> w1(iter.Value()); 110 StringWeight<L, S> w2; 111 for (iter.Next(); !iter.Done(); iter.Next()) 112 w2.PushBack(iter.Value()); 113 return make_pair(w1, w2); 114 } 115 116 void Reset() { done_ = weight_.Size() <= 1; } 117 118 private: 119 StringWeight<L, S> weight_; 120 bool done_; 121 }; 122 123 124 // Factor a GallicWeight using StringFactor. 125 template <class L, class W, StringType S = STRING_LEFT> 126 class GallicFactor { 127 public: 128 GallicFactor(const GallicWeight<L, W, S> &w) 129 : weight_(w), done_(w.Value1().Size() <= 1) {} 130 131 bool Done() const { return done_; } 132 133 void Next() { done_ = true; } 134 135 pair< GallicWeight<L, W, S>, GallicWeight<L, W, S> > Value() const { 136 StringFactor<L, S> iter(weight_.Value1()); 137 GallicWeight<L, W, S> w1(iter.Value().first, weight_.Value2()); 138 GallicWeight<L, W, S> w2(iter.Value().second, W::One()); 139 return make_pair(w1, w2); 140 } 141 142 void Reset() { done_ = weight_.Value1().Size() <= 1; } 143 144 private: 145 GallicWeight<L, W, S> weight_; 146 bool done_; 147 }; 148 149 150 // Implementation class for FactorWeight 151 template <class A, class F> 152 class FactorWeightFstImpl 153 : public CacheImpl<A> { 154 public: 155 using FstImpl<A>::SetType; 156 using FstImpl<A>::SetProperties; 157 using FstImpl<A>::SetInputSymbols; 158 using FstImpl<A>::SetOutputSymbols; 159 160 using CacheBaseImpl< CacheState<A> >::PushArc; 161 using CacheBaseImpl< CacheState<A> >::HasStart; 162 using CacheBaseImpl< CacheState<A> >::HasFinal; 163 using CacheBaseImpl< CacheState<A> >::HasArcs; 164 using CacheBaseImpl< CacheState<A> >::SetArcs; 165 using CacheBaseImpl< CacheState<A> >::SetFinal; 166 using CacheBaseImpl< CacheState<A> >::SetStart; 167 168 typedef A Arc; 169 typedef typename A::Label Label; 170 typedef typename A::Weight Weight; 171 typedef typename A::StateId StateId; 172 typedef F FactorIterator; 173 174 struct Element { 175 Element() {} 176 177 Element(StateId s, Weight w) : state(s), weight(w) {} 178 179 StateId state; // Input state Id 180 Weight weight; // Residual weight 181 }; 182 183 FactorWeightFstImpl(const Fst<A> &fst, const FactorWeightOptions<A> &opts) 184 : CacheImpl<A>(opts), 185 fst_(fst.Copy()), 186 delta_(opts.delta), 187 mode_(opts.mode), 188 final_ilabel_(opts.final_ilabel), 189 final_olabel_(opts.final_olabel) { 190 SetType("factor_weight"); 191 uint64 props = fst.Properties(kFstProperties, false); 192 SetProperties(FactorWeightProperties(props), kCopyProperties); 193 194 SetInputSymbols(fst.InputSymbols()); 195 SetOutputSymbols(fst.OutputSymbols()); 196 197 if (mode_ == 0) 198 LOG(WARNING) << "FactorWeightFst: factor mode is set to 0: " 199 << "factoring neither arc weights nor final weights."; 200 } 201 202 FactorWeightFstImpl(const FactorWeightFstImpl<A, F> &impl) 203 : CacheImpl<A>(impl), 204 fst_(impl.fst_->Copy(true)), 205 delta_(impl.delta_), 206 mode_(impl.mode_), 207 final_ilabel_(impl.final_ilabel_), 208 final_olabel_(impl.final_olabel_) { 209 SetType("factor_weight"); 210 SetProperties(impl.Properties(), kCopyProperties); 211 SetInputSymbols(impl.InputSymbols()); 212 SetOutputSymbols(impl.OutputSymbols()); 213 } 214 215 ~FactorWeightFstImpl() { 216 delete fst_; 217 } 218 219 StateId Start() { 220 if (!HasStart()) { 221 StateId s = fst_->Start(); 222 if (s == kNoStateId) 223 return kNoStateId; 224 StateId start = FindState(Element(fst_->Start(), Weight::One())); 225 SetStart(start); 226 } 227 return CacheImpl<A>::Start(); 228 } 229 230 Weight Final(StateId s) { 231 if (!HasFinal(s)) { 232 const Element &e = elements_[s]; 233 // TODO: fix so cast is unnecessary 234 Weight w = e.state == kNoStateId 235 ? e.weight 236 : (Weight) Times(e.weight, fst_->Final(e.state)); 237 FactorIterator f(w); 238 if (!(mode_ & kFactorFinalWeights) || f.Done()) 239 SetFinal(s, w); 240 else 241 SetFinal(s, Weight::Zero()); 242 } 243 return CacheImpl<A>::Final(s); 244 } 245 246 size_t NumArcs(StateId s) { 247 if (!HasArcs(s)) 248 Expand(s); 249 return CacheImpl<A>::NumArcs(s); 250 } 251 252 size_t NumInputEpsilons(StateId s) { 253 if (!HasArcs(s)) 254 Expand(s); 255 return CacheImpl<A>::NumInputEpsilons(s); 256 } 257 258 size_t NumOutputEpsilons(StateId s) { 259 if (!HasArcs(s)) 260 Expand(s); 261 return CacheImpl<A>::NumOutputEpsilons(s); 262 } 263 264 uint64 Properties() const { return Properties(kFstProperties); } 265 266 // Set error if found; return FST impl properties. 267 uint64 Properties(uint64 mask) const { 268 if ((mask & kError) && fst_->Properties(kError, false)) 269 SetProperties(kError, kError); 270 return FstImpl<Arc>::Properties(mask); 271 } 272 273 void InitArcIterator(StateId s, ArcIteratorData<A> *data) { 274 if (!HasArcs(s)) 275 Expand(s); 276 CacheImpl<A>::InitArcIterator(s, data); 277 } 278 279 280 // Find state corresponding to an element. Create new state 281 // if element not found. 282 StateId FindState(const Element &e) { 283 if (!(mode_ & kFactorArcWeights) && e.weight == Weight::One()) { 284 while (unfactored_.size() <= e.state) 285 unfactored_.push_back(kNoStateId); 286 if (unfactored_[e.state] == kNoStateId) { 287 unfactored_[e.state] = elements_.size(); 288 elements_.push_back(e); 289 } 290 return unfactored_[e.state]; 291 } else { 292 typename ElementMap::iterator eit = element_map_.find(e); 293 if (eit != element_map_.end()) { 294 return (*eit).second; 295 } else { 296 StateId s = elements_.size(); 297 elements_.push_back(e); 298 element_map_.insert(pair<const Element, StateId>(e, s)); 299 return s; 300 } 301 } 302 } 303 304 // Computes the outgoing transitions from a state, creating new destination 305 // states as needed. 306 void Expand(StateId s) { 307 Element e = elements_[s]; 308 if (e.state != kNoStateId) { 309 for (ArcIterator< Fst<A> > ait(*fst_, e.state); 310 !ait.Done(); 311 ait.Next()) { 312 const A &arc = ait.Value(); 313 Weight w = Times(e.weight, arc.weight); 314 FactorIterator fit(w); 315 if (!(mode_ & kFactorArcWeights) || fit.Done()) { 316 StateId d = FindState(Element(arc.nextstate, Weight::One())); 317 PushArc(s, Arc(arc.ilabel, arc.olabel, w, d)); 318 } else { 319 for (; !fit.Done(); fit.Next()) { 320 const pair<Weight, Weight> &p = fit.Value(); 321 StateId d = FindState(Element(arc.nextstate, 322 p.second.Quantize(delta_))); 323 PushArc(s, Arc(arc.ilabel, arc.olabel, p.first, d)); 324 } 325 } 326 } 327 } 328 329 if ((mode_ & kFactorFinalWeights) && 330 ((e.state == kNoStateId) || 331 (fst_->Final(e.state) != Weight::Zero()))) { 332 Weight w = e.state == kNoStateId 333 ? e.weight 334 : Times(e.weight, fst_->Final(e.state)); 335 for (FactorIterator fit(w); 336 !fit.Done(); 337 fit.Next()) { 338 const pair<Weight, Weight> &p = fit.Value(); 339 StateId d = FindState(Element(kNoStateId, 340 p.second.Quantize(delta_))); 341 PushArc(s, Arc(final_ilabel_, final_olabel_, p.first, d)); 342 } 343 } 344 SetArcs(s); 345 } 346 347 private: 348 static const size_t kPrime = 7853; 349 350 // Equality function for Elements, assume weights have been quantized. 351 class ElementEqual { 352 public: 353 bool operator()(const Element &x, const Element &y) const { 354 return x.state == y.state && x.weight == y.weight; 355 } 356 }; 357 358 // Hash function for Elements to Fst states. 359 class ElementKey { 360 public: 361 size_t operator()(const Element &x) const { 362 return static_cast<size_t>(x.state * kPrime + x.weight.Hash()); 363 } 364 private: 365 }; 366 367 typedef unordered_map<Element, StateId, ElementKey, ElementEqual> ElementMap; 368 369 const Fst<A> *fst_; 370 float delta_; 371 uint32 mode_; // factoring arc and/or final weights 372 Label final_ilabel_; // ilabel of arc created when factoring final w's 373 Label final_olabel_; // olabel of arc created when factoring final w's 374 vector<Element> elements_; // mapping Fst state to Elements 375 ElementMap element_map_; // mapping Elements to Fst state 376 // mapping between old/new 'StateId' for states that do not need to 377 // be factored when 'mode_' is '0' or 'kFactorFinalWeights' 378 vector<StateId> unfactored_; 379 380 void operator=(const FactorWeightFstImpl<A, F> &); // disallow 381 }; 382 383 template <class A, class F> const size_t FactorWeightFstImpl<A, F>::kPrime; 384 385 386 // FactorWeightFst takes as template parameter a FactorIterator as 387 // defined above. The result of weight factoring is a transducer 388 // equivalent to the input whose path weights have been factored 389 // according to the FactorIterator. States and transitions will be 390 // added as necessary. The algorithm is a generalization to arbitrary 391 // weights of the second step of the input epsilon-normalization 392 // algorithm due to Mohri, "Generic epsilon-removal and input 393 // epsilon-normalization algorithms for weighted transducers", 394 // International Journal of Computer Science 13(1): 129-143 (2002). 395 // 396 // This class attaches interface to implementation and handles 397 // reference counting, delegating most methods to ImplToFst. 398 template <class A, class F> 399 class FactorWeightFst : public ImplToFst< FactorWeightFstImpl<A, F> > { 400 public: 401 friend class ArcIterator< FactorWeightFst<A, F> >; 402 friend class StateIterator< FactorWeightFst<A, F> >; 403 404 typedef A Arc; 405 typedef typename A::Weight Weight; 406 typedef typename A::StateId StateId; 407 typedef CacheState<A> State; 408 typedef FactorWeightFstImpl<A, F> Impl; 409 410 FactorWeightFst(const Fst<A> &fst) 411 : ImplToFst<Impl>(new Impl(fst, FactorWeightOptions<A>())) {} 412 413 FactorWeightFst(const Fst<A> &fst, const FactorWeightOptions<A> &opts) 414 : ImplToFst<Impl>(new Impl(fst, opts)) {} 415 416 // See Fst<>::Copy() for doc. 417 FactorWeightFst(const FactorWeightFst<A, F> &fst, bool copy) 418 : ImplToFst<Impl>(fst, copy) {} 419 420 // Get a copy of this FactorWeightFst. See Fst<>::Copy() for further doc. 421 virtual FactorWeightFst<A, F> *Copy(bool copy = false) const { 422 return new FactorWeightFst<A, F>(*this, copy); 423 } 424 425 virtual inline void InitStateIterator(StateIteratorData<A> *data) const; 426 427 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { 428 GetImpl()->InitArcIterator(s, data); 429 } 430 431 private: 432 // Makes visible to friends. 433 Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } 434 435 void operator=(const FactorWeightFst<A, F> &fst); // Disallow 436 }; 437 438 439 // Specialization for FactorWeightFst. 440 template<class A, class F> 441 class StateIterator< FactorWeightFst<A, F> > 442 : public CacheStateIterator< FactorWeightFst<A, F> > { 443 public: 444 explicit StateIterator(const FactorWeightFst<A, F> &fst) 445 : CacheStateIterator< FactorWeightFst<A, F> >(fst, fst.GetImpl()) {} 446 }; 447 448 449 // Specialization for FactorWeightFst. 450 template <class A, class F> 451 class ArcIterator< FactorWeightFst<A, F> > 452 : public CacheArcIterator< FactorWeightFst<A, F> > { 453 public: 454 typedef typename A::StateId StateId; 455 456 ArcIterator(const FactorWeightFst<A, F> &fst, StateId s) 457 : CacheArcIterator< FactorWeightFst<A, F> >(fst.GetImpl(), s) { 458 if (!fst.GetImpl()->HasArcs(s)) 459 fst.GetImpl()->Expand(s); 460 } 461 462 private: 463 DISALLOW_COPY_AND_ASSIGN(ArcIterator); 464 }; 465 466 template <class A, class F> inline 467 void FactorWeightFst<A, F>::InitStateIterator(StateIteratorData<A> *data) const 468 { 469 data->base = new StateIterator< FactorWeightFst<A, F> >(*this); 470 } 471 472 473 } // namespace fst 474 475 #endif // FST_LIB_FACTOR_WEIGHT_H__ 476