// map.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley (at) google.com (Michael Riley)
//
// \file
// Class to map over/transform states e.g., sort transitions
// Consider using when operation does not change the number of states.

#ifndef FST_LIB_STATE_MAP_H__
#define FST_LIB_STATE_MAP_H__

#include <algorithm>
#include <tr1/unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
#include <string>
#include <utility>
using std::pair; using std::make_pair;

#include <fst/cache.h>
#include <fst/arc-map.h>
#include <fst/mutable-fst.h>


namespace fst {

// StateMapper Interface - class determines how states are mapped.
// Useful for implementing operations that do not change the number of states.
//
// class StateMapper {
//  public:
//   typedef A FromArc;
//   typedef B ToArc;
//
//   // Typical constructor
//   StateMapper(const Fst<A> &fst);
//   // Required copy constructor that allows updating Fst argument;
//   // pass only if relevant and changed.
//   StateMapper(const StateMapper &mapper, const Fst<A> *fst = 0);
//
//   // Specifies initial state of result
//   B::StateId Start() const;
//   // Specifies state's final weight in result
//   B::Weight Final(B::StateId s) const;
//
//   // These methods iterate through a state's arcs in result
//   // Specifies state to iterate over
//   void SetState(B::StateId s);
//   // End of arcs?
//   bool Done() const;
//   // Current arc
//   const B &Value() const;
//   // Advance to next arc (when !Done)
//   void Next();
//
//   // Specifies input symbol table action the mapper requires (see above).
//   MapSymbolsAction InputSymbolsAction() const;
//   // Specifies output symbol table action the mapper requires (see above).
//   MapSymbolsAction OutputSymbolsAction() const;
//   // This specifies the known properties of an Fst mapped by this
//   // mapper. It takes as argument the input Fst's known properties.
//   uint64 Properties(uint64 props) const;
// };
//
// We include various state map versions below. One dimension of
// variation is whether the mapping mutates its input, writes to a
// new result Fst, or is an on-the-fly Fst. Another dimension is how
// we pass the mapper. We allow passing the mapper by pointer
// for cases that we need to change the state of the user's mapper.
// We also include map versions that pass the mapper
// by value or const reference when this suffices.

// Maps an arc type A using a mapper function object C, passed
// by pointer. This version modifies its Fst input.
89 template<class A, class C> 90 void StateMap(MutableFst<A> *fst, C* mapper) { 91 typedef typename A::StateId StateId; 92 typedef typename A::Weight Weight; 93 94 if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) 95 fst->SetInputSymbols(0); 96 97 if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) 98 fst->SetOutputSymbols(0); 99 100 if (fst->Start() == kNoStateId) 101 return; 102 103 uint64 props = fst->Properties(kFstProperties, false); 104 105 fst->SetStart(mapper->Start()); 106 107 for (StateId s = 0; s < fst->NumStates(); ++s) { 108 mapper->SetState(s); 109 fst->DeleteArcs(s); 110 for (; !mapper->Done(); mapper->Next()) 111 fst->AddArc(s, mapper->Value()); 112 fst->SetFinal(s, mapper->Final(s)); 113 } 114 115 fst->SetProperties(mapper->Properties(props), kFstProperties); 116 } 117 118 // Maps an arc type A using a mapper function object C, passed 119 // by value. This version modifies its Fst input. 120 template<class A, class C> 121 void StateMap(MutableFst<A> *fst, C mapper) { 122 StateMap(fst, &mapper); 123 } 124 125 126 // Maps an arc type A to an arc type B using mapper function 127 // object C, passed by pointer. This version writes the mapped 128 // input Fst to an output MutableFst. 
129 template<class A, class B, class C> 130 void StateMap(const Fst<A> &ifst, MutableFst<B> *ofst, C* mapper) { 131 typedef typename A::StateId StateId; 132 typedef typename A::Weight Weight; 133 134 ofst->DeleteStates(); 135 136 if (mapper->InputSymbolsAction() == MAP_COPY_SYMBOLS) 137 ofst->SetInputSymbols(ifst.InputSymbols()); 138 else if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) 139 ofst->SetInputSymbols(0); 140 141 if (mapper->OutputSymbolsAction() == MAP_COPY_SYMBOLS) 142 ofst->SetOutputSymbols(ifst.OutputSymbols()); 143 else if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) 144 ofst->SetOutputSymbols(0); 145 146 uint64 iprops = ifst.Properties(kCopyProperties, false); 147 148 if (ifst.Start() == kNoStateId) { 149 if (iprops & kError) ofst->SetProperties(kError, kError); 150 return; 151 } 152 153 // Add all states. 154 if (ifst.Properties(kExpanded, false)) 155 ofst->ReserveStates(CountStates(ifst)); 156 for (StateIterator< Fst<A> > siter(ifst); !siter.Done(); siter.Next()) 157 ofst->AddState(); 158 159 ofst->SetStart(mapper->Start()); 160 161 for (StateIterator< Fst<A> > siter(ifst); !siter.Done(); siter.Next()) { 162 StateId s = siter.Value(); 163 mapper->SetState(s); 164 for (; !mapper->Done(); mapper->Next()) 165 ofst->AddArc(s, mapper->Value()); 166 ofst->SetFinal(s, mapper->Final(s)); 167 } 168 169 uint64 oprops = ofst->Properties(kFstProperties, false); 170 ofst->SetProperties(mapper->Properties(iprops) | oprops, kFstProperties); 171 } 172 173 // Maps an arc type A to an arc type B using mapper function 174 // object C, passed by value. This version writes the mapped input 175 // Fst to an output MutableFst. 176 template<class A, class B, class C> 177 void StateMap(const Fst<A> &ifst, MutableFst<B> *ofst, C mapper) { 178 StateMap(ifst, ofst, &mapper); 179 } 180 181 typedef CacheOptions StateMapFstOptions; 182 183 template <class A, class B, class C> class StateMapFst; 184 185 // Implementation of delayed StateMapFst. 
186 template <class A, class B, class C> 187 class StateMapFstImpl : public CacheImpl<B> { 188 public: 189 using FstImpl<B>::SetType; 190 using FstImpl<B>::SetProperties; 191 using FstImpl<B>::SetInputSymbols; 192 using FstImpl<B>::SetOutputSymbols; 193 194 using CacheImpl<B>::PushArc; 195 using CacheImpl<B>::HasArcs; 196 using CacheImpl<B>::HasFinal; 197 using CacheImpl<B>::HasStart; 198 using CacheImpl<B>::SetArcs; 199 using CacheImpl<B>::SetFinal; 200 using CacheImpl<B>::SetStart; 201 202 friend class StateIterator< StateMapFst<A, B, C> >; 203 204 typedef B Arc; 205 typedef typename B::Weight Weight; 206 typedef typename B::StateId StateId; 207 208 StateMapFstImpl(const Fst<A> &fst, const C &mapper, 209 const StateMapFstOptions& opts) 210 : CacheImpl<B>(opts), 211 fst_(fst.Copy()), 212 mapper_(new C(mapper, fst_)), 213 own_mapper_(true) { 214 Init(); 215 } 216 217 StateMapFstImpl(const Fst<A> &fst, C *mapper, 218 const StateMapFstOptions& opts) 219 : CacheImpl<B>(opts), 220 fst_(fst.Copy()), 221 mapper_(mapper), 222 own_mapper_(false) { 223 Init(); 224 } 225 226 StateMapFstImpl(const StateMapFstImpl<A, B, C> &impl) 227 : CacheImpl<B>(impl), 228 fst_(impl.fst_->Copy(true)), 229 mapper_(new C(*impl.mapper_, fst_)), 230 own_mapper_(true) { 231 Init(); 232 } 233 234 ~StateMapFstImpl() { 235 delete fst_; 236 if (own_mapper_) delete mapper_; 237 } 238 239 StateId Start() { 240 if (!HasStart()) 241 SetStart(mapper_->Start()); 242 return CacheImpl<B>::Start(); 243 } 244 245 Weight Final(StateId s) { 246 if (!HasFinal(s)) 247 SetFinal(s, mapper_->Final(s)); 248 return CacheImpl<B>::Final(s); 249 } 250 251 size_t NumArcs(StateId s) { 252 if (!HasArcs(s)) 253 Expand(s); 254 return CacheImpl<B>::NumArcs(s); 255 } 256 257 size_t NumInputEpsilons(StateId s) { 258 if (!HasArcs(s)) 259 Expand(s); 260 return CacheImpl<B>::NumInputEpsilons(s); 261 } 262 263 size_t NumOutputEpsilons(StateId s) { 264 if (!HasArcs(s)) 265 Expand(s); 266 return CacheImpl<B>::NumOutputEpsilons(s); 267 
} 268 269 void InitStateIterator(StateIteratorData<A> *data) const { 270 fst_->InitStateIterator(data); 271 } 272 273 void InitArcIterator(StateId s, ArcIteratorData<B> *data) { 274 if (!HasArcs(s)) 275 Expand(s); 276 CacheImpl<B>::InitArcIterator(s, data); 277 } 278 279 uint64 Properties() const { return Properties(kFstProperties); } 280 281 // Set error if found; return FST impl properties. 282 uint64 Properties(uint64 mask) const { 283 if ((mask & kError) && (fst_->Properties(kError, false) || 284 (mapper_->Properties(0) & kError))) 285 SetProperties(kError, kError); 286 return FstImpl<Arc>::Properties(mask); 287 } 288 289 void Expand(StateId s) { 290 // Add exiting arcs. 291 for (mapper_->SetState(s); !mapper_->Done(); mapper_->Next()) 292 PushArc(s, mapper_->Value()); 293 SetArcs(s); 294 } 295 296 const Fst<A> &GetFst() const { 297 return *fst_; 298 } 299 300 private: 301 void Init() { 302 SetType("statemap"); 303 304 if (mapper_->InputSymbolsAction() == MAP_COPY_SYMBOLS) 305 SetInputSymbols(fst_->InputSymbols()); 306 else if (mapper_->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) 307 SetInputSymbols(0); 308 309 if (mapper_->OutputSymbolsAction() == MAP_COPY_SYMBOLS) 310 SetOutputSymbols(fst_->OutputSymbols()); 311 else if (mapper_->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) 312 SetOutputSymbols(0); 313 314 uint64 props = fst_->Properties(kCopyProperties, false); 315 SetProperties(mapper_->Properties(props)); 316 } 317 318 const Fst<A> *fst_; 319 C* mapper_; 320 bool own_mapper_; 321 322 void operator=(const StateMapFstImpl<A, B, C> &); // disallow 323 }; 324 325 326 // Maps an arc type A to an arc type B using Mapper function object 327 // C. This version is a delayed Fst. 
328 template <class A, class B, class C> 329 class StateMapFst : public ImplToFst< StateMapFstImpl<A, B, C> > { 330 public: 331 friend class ArcIterator< StateMapFst<A, B, C> >; 332 333 typedef B Arc; 334 typedef typename B::Weight Weight; 335 typedef typename B::StateId StateId; 336 typedef CacheState<B> State; 337 typedef StateMapFstImpl<A, B, C> Impl; 338 339 StateMapFst(const Fst<A> &fst, const C &mapper, 340 const StateMapFstOptions& opts) 341 : ImplToFst<Impl>(new Impl(fst, mapper, opts)) {} 342 343 StateMapFst(const Fst<A> &fst, C* mapper, const StateMapFstOptions& opts) 344 : ImplToFst<Impl>(new Impl(fst, mapper, opts)) {} 345 346 StateMapFst(const Fst<A> &fst, const C &mapper) 347 : ImplToFst<Impl>(new Impl(fst, mapper, StateMapFstOptions())) {} 348 349 StateMapFst(const Fst<A> &fst, C* mapper) 350 : ImplToFst<Impl>(new Impl(fst, mapper, StateMapFstOptions())) {} 351 352 // See Fst<>::Copy() for doc. 353 StateMapFst(const StateMapFst<A, B, C> &fst, bool safe = false) 354 : ImplToFst<Impl>(fst, safe) {} 355 356 // Get a copy of this StateMapFst. See Fst<>::Copy() for further doc. 357 virtual StateMapFst<A, B, C> *Copy(bool safe = false) const { 358 return new StateMapFst<A, B, C>(*this, safe); 359 } 360 361 virtual void InitStateIterator(StateIteratorData<A> *data) const { 362 GetImpl()->InitStateIterator(data); 363 } 364 365 virtual void InitArcIterator(StateId s, ArcIteratorData<B> *data) const { 366 GetImpl()->InitArcIterator(s, data); 367 } 368 369 protected: 370 Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } 371 372 private: 373 void operator=(const StateMapFst<A, B, C> &fst); // disallow 374 }; 375 376 377 // Specialization for StateMapFst. 
378 template <class A, class B, class C> 379 class ArcIterator< StateMapFst<A, B, C> > 380 : public CacheArcIterator< StateMapFst<A, B, C> > { 381 public: 382 typedef typename A::StateId StateId; 383 384 ArcIterator(const StateMapFst<A, B, C> &fst, StateId s) 385 : CacheArcIterator< StateMapFst<A, B, C> >(fst.GetImpl(), s) { 386 if (!fst.GetImpl()->HasArcs(s)) 387 fst.GetImpl()->Expand(s); 388 } 389 390 private: 391 DISALLOW_COPY_AND_ASSIGN(ArcIterator); 392 }; 393 394 // 395 // Utility Mappers 396 // 397 398 // Mapper that returns its input. 399 template <class A> 400 class IdentityStateMapper { 401 public: 402 typedef A FromArc; 403 typedef A ToArc; 404 405 typedef typename A::StateId StateId; 406 typedef typename A::Weight Weight; 407 408 explicit IdentityStateMapper(const Fst<A> &fst) : fst_(fst), aiter_(0) {} 409 410 // Allows updating Fst argument; pass only if changed. 411 IdentityStateMapper(const IdentityStateMapper<A> &mapper, 412 const Fst<A> *fst = 0) 413 : fst_(fst ? *fst : mapper.fst_), aiter_(0) {} 414 415 ~IdentityStateMapper() { delete aiter_; } 416 417 StateId Start() const { return fst_.Start(); } 418 419 Weight Final(StateId s) const { return fst_.Final(s); } 420 421 void SetState(StateId s) { 422 if (aiter_) delete aiter_; 423 aiter_ = new ArcIterator< Fst<A> >(fst_, s); 424 } 425 426 bool Done() const { return aiter_->Done(); } 427 const A &Value() const { return aiter_->Value(); } 428 void Next() { aiter_->Next(); } 429 430 MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } 431 MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS;} 432 433 uint64 Properties(uint64 props) const { return props; } 434 435 private: 436 const Fst<A> &fst_; 437 ArcIterator< Fst<A> > *aiter_; 438 }; 439 440 template <class A> 441 class ArcSumMapper { 442 public: 443 typedef A FromArc; 444 typedef A ToArc; 445 446 typedef typename A::StateId StateId; 447 typedef typename A::Weight Weight; 448 449 explicit ArcSumMapper(const 
Fst<A> &fst) : fst_(fst), i_(0) {} 450 451 // Allows updating Fst argument; pass only if changed. 452 ArcSumMapper(const ArcSumMapper<A> &mapper, 453 const Fst<A> *fst = 0) 454 : fst_(fst ? *fst : mapper.fst_), i_(0) {} 455 456 StateId Start() const { return fst_.Start(); } 457 Weight Final(StateId s) const { return fst_.Final(s); } 458 459 void SetState(StateId s) { 460 i_ = 0; 461 arcs_.clear(); 462 arcs_.reserve(fst_.NumArcs(s)); 463 for (ArcIterator<Fst<A> > aiter(fst_, s); !aiter.Done(); aiter.Next()) 464 arcs_.push_back(aiter.Value()); 465 466 // First sorts the exiting arcs by input label, output label 467 // and destination state and then sums weights of arcs with 468 // the same input label, output label, and destination state. 469 sort(arcs_.begin(), arcs_.end(), comp_); 470 size_t narcs = 0; 471 for (size_t i = 0; i < arcs_.size(); ++i) { 472 if (narcs > 0 && equal_(arcs_[i], arcs_[narcs - 1])) { 473 arcs_[narcs - 1].weight = Plus(arcs_[narcs - 1].weight, 474 arcs_[i].weight); 475 } else { 476 arcs_[narcs++] = arcs_[i]; 477 } 478 } 479 arcs_.resize(narcs); 480 } 481 482 bool Done() const { return i_ >= arcs_.size(); } 483 const A &Value() const { return arcs_[i_]; } 484 void Next() { ++i_; } 485 486 MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } 487 MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; } 488 489 uint64 Properties(uint64 props) const { 490 return props & kArcSortProperties & 491 kDeleteArcsProperties & kWeightInvariantProperties; 492 } 493 494 private: 495 struct Compare { 496 bool operator()(const A& x, const A& y) { 497 if (x.ilabel < y.ilabel) return true; 498 if (x.ilabel > y.ilabel) return false; 499 if (x.olabel < y.olabel) return true; 500 if (x.olabel > y.olabel) return false; 501 if (x.nextstate < y.nextstate) return true; 502 if (x.nextstate > y.nextstate) return false; 503 return false; 504 } 505 }; 506 507 struct Equal { 508 bool operator()(const A& x, const A& y) { 509 return 
(x.ilabel == y.ilabel && 510 x.olabel == y.olabel && 511 x.nextstate == y.nextstate); 512 } 513 }; 514 515 const Fst<A> &fst_; 516 Compare comp_; 517 Equal equal_; 518 vector<A> arcs_; 519 ssize_t i_; // current arc position 520 521 void operator=(const ArcSumMapper<A> &); // disallow 522 }; 523 524 template <class A> 525 class ArcUniqueMapper { 526 public: 527 typedef A FromArc; 528 typedef A ToArc; 529 530 typedef typename A::StateId StateId; 531 typedef typename A::Weight Weight; 532 533 explicit ArcUniqueMapper(const Fst<A> &fst) : fst_(fst), i_(0) {} 534 535 // Allows updating Fst argument; pass only if changed. 536 ArcUniqueMapper(const ArcUniqueMapper<A> &mapper, 537 const Fst<A> *fst = 0) 538 : fst_(fst ? *fst : mapper.fst_), i_(0) {} 539 540 StateId Start() const { return fst_.Start(); } 541 Weight Final(StateId s) const { return fst_.Final(s); } 542 543 void SetState(StateId s) { 544 i_ = 0; 545 arcs_.clear(); 546 arcs_.reserve(fst_.NumArcs(s)); 547 for (ArcIterator<Fst<A> > aiter(fst_, s); !aiter.Done(); aiter.Next()) 548 arcs_.push_back(aiter.Value()); 549 550 // First sorts the exiting arcs by input label, output label 551 // and destination state and then uniques identical arcs 552 sort(arcs_.begin(), arcs_.end(), comp_); 553 typename vector<A>::iterator unique_end = 554 unique(arcs_.begin(), arcs_.end(), equal_); 555 arcs_.resize(unique_end - arcs_.begin()); 556 } 557 558 bool Done() const { return i_ >= arcs_.size(); } 559 const A &Value() const { return arcs_[i_]; } 560 void Next() { ++i_; } 561 562 MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } 563 MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; } 564 565 uint64 Properties(uint64 props) const { 566 return props & kArcSortProperties & kDeleteArcsProperties; 567 } 568 569 private: 570 struct Compare { 571 bool operator()(const A& x, const A& y) { 572 if (x.ilabel < y.ilabel) return true; 573 if (x.ilabel > y.ilabel) return false; 574 if (x.olabel < 
y.olabel) return true; 575 if (x.olabel > y.olabel) return false; 576 if (x.nextstate < y.nextstate) return true; 577 if (x.nextstate > y.nextstate) return false; 578 return false; 579 } 580 }; 581 582 struct Equal { 583 bool operator()(const A& x, const A& y) { 584 return (x.ilabel == y.ilabel && 585 x.olabel == y.olabel && 586 x.nextstate == y.nextstate && 587 x.weight == y.weight); 588 } 589 }; 590 591 const Fst<A> &fst_; 592 Compare comp_; 593 Equal equal_; 594 vector<A> arcs_; 595 ssize_t i_; // current arc position 596 597 void operator=(const ArcUniqueMapper<A> &); // disallow 598 }; 599 600 601 } // namespace fst 602 603 #endif // FST_LIB_STATE_MAP_H__ 604