1 // relabel.h 2 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Copyright 2005-2010 Google, Inc. 16 // Author: johans (at) google.com (Johan Schalkwyk) 17 // 18 // \file 19 // Functions and classes to relabel an Fst (either on input or output) 20 // 21 #ifndef FST_LIB_RELABEL_H__ 22 #define FST_LIB_RELABEL_H__ 23 24 #include <tr1/unordered_map> 25 using std::tr1::unordered_map; 26 using std::tr1::unordered_multimap; 27 #include <string> 28 #include <utility> 29 using std::pair; using std::make_pair; 30 #include <vector> 31 using std::vector; 32 33 #include <fst/cache.h> 34 #include <fst/test-properties.h> 35 36 37 #include <tr1/unordered_map> 38 using std::tr1::unordered_map; 39 using std::tr1::unordered_multimap; 40 41 namespace fst { 42 43 // 44 // Relabels either the input labels or output labels. The old to 45 // new labels are specified using a vector of pair<Label,Label>. 46 // Any label associations not specified are assumed to be identity 47 // mapping. 48 // 49 // \param fst input fst, must be mutable 50 // \param ipairs vector of input label pairs indicating old to new mapping 51 // \param opairs vector of output label pairs indicating old to new mapping 52 // 53 template <class A> 54 void Relabel( 55 MutableFst<A> *fst, 56 const vector<pair<typename A::Label, typename A::Label> >& ipairs, 57 const vector<pair<typename A::Label, typename A::Label> >& opairs) { 58 typedef typename A::StateId StateId; 59 typedef typename A::Label Label; 60 61 uint64 props = fst->Properties(kFstProperties, false); 62 63 // construct label to label hash. 64 unordered_map<Label, Label> input_map; 65 for (size_t i = 0; i < ipairs.size(); ++i) { 66 input_map[ipairs[i].first] = ipairs[i].second; 67 } 68 69 unordered_map<Label, Label> output_map; 70 for (size_t i = 0; i < opairs.size(); ++i) { 71 output_map[opairs[i].first] = opairs[i].second; 72 } 73 74 for (StateIterator<MutableFst<A> > siter(*fst); 75 !siter.Done(); siter.Next()) { 76 StateId s = siter.Value(); 77 for (MutableArcIterator<MutableFst<A> > aiter(fst, s); 78 !aiter.Done(); aiter.Next()) { 79 A arc = aiter.Value(); 80 81 // relabel input 82 // only relabel if relabel pair defined 83 typename unordered_map<Label, Label>::iterator it = 84 input_map.find(arc.ilabel); 85 if (it != input_map.end()) { 86 if (it->second == kNoLabel) { 87 FSTERROR() << "Input symbol id " << arc.ilabel 88 << " missing from target vocabulary"; 89 fst->SetProperties(kError, kError); 90 return; 91 } 92 arc.ilabel = it->second; 93 } 94 95 // relabel output 96 it = output_map.find(arc.olabel); 97 if (it != output_map.end()) { 98 if (it->second == kNoLabel) { 99 FSTERROR() << "Output symbol id " << arc.olabel 100 << " missing from target vocabulary"; 101 fst->SetProperties(kError, kError); 102 return; 103 } 104 arc.olabel = it->second; 105 } 106 107 aiter.SetValue(arc); 108 } 109 } 110 111 fst->SetProperties(RelabelProperties(props), kFstProperties); 112 } 113 114 // 115 // Relabels either the input labels or output labels. The old to 116 // new labels mappings are specified using an input Symbol set. 117 // Any label associations not specified are assumed to be identity 118 // mapping. 119 // 120 // \param fst input fst, must be mutable 121 // \param new_isymbols symbol set indicating new mapping of input symbols 122 // \param new_osymbols symbol set indicating new mapping of output symbols 123 // 124 template<class A> 125 void Relabel(MutableFst<A> *fst, 126 const SymbolTable* new_isymbols, 127 const SymbolTable* new_osymbols) { 128 Relabel(fst, 129 fst->InputSymbols(), new_isymbols, true, 130 fst->OutputSymbols(), new_osymbols, true); 131 } 132 133 template<class A> 134 void Relabel(MutableFst<A> *fst, 135 const SymbolTable* old_isymbols, 136 const SymbolTable* new_isymbols, 137 bool attach_new_isymbols, 138 const SymbolTable* old_osymbols, 139 const SymbolTable* new_osymbols, 140 bool attach_new_osymbols) { 141 typedef typename A::StateId StateId; 142 typedef typename A::Label Label; 143 144 vector<pair<Label, Label> > ipairs; 145 if (old_isymbols && new_isymbols) { 146 for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done(); 147 syms_iter.Next()) { 148 string isymbol = syms_iter.Symbol(); 149 int isymbol_val = syms_iter.Value(); 150 int new_isymbol_val = new_isymbols->Find(isymbol); 151 ipairs.push_back(make_pair(isymbol_val, new_isymbol_val)); 152 } 153 if (attach_new_isymbols) 154 fst->SetInputSymbols(new_isymbols); 155 } 156 157 vector<pair<Label, Label> > opairs; 158 if (old_osymbols && new_osymbols) { 159 for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done(); 160 syms_iter.Next()) { 161 string osymbol = syms_iter.Symbol(); 162 int osymbol_val = syms_iter.Value(); 163 int new_osymbol_val = new_osymbols->Find(osymbol); 164 opairs.push_back(make_pair(osymbol_val, new_osymbol_val)); 165 } 166 if (attach_new_osymbols) 167 fst->SetOutputSymbols(new_osymbols); 168 } 169 170 // call relabel using vector of relabel pairs. 171 Relabel(fst, ipairs, opairs); 172 } 173 174 175 typedef CacheOptions RelabelFstOptions; 176 177 template <class A> class RelabelFst; 178 179 // 180 // \class RelabelFstImpl 181 // \brief Implementation for delayed relabeling 182 // 183 // Relabels an FST from one symbol set to another. Relabeling 184 // can either be on input or output space. RelabelFst implements 185 // a delayed version of the relabel. Arcs are relabeled on the fly 186 // and not cached. I.e each request is recomputed. 187 // 188 template<class A> 189 class RelabelFstImpl : public CacheImpl<A> { 190 friend class StateIterator< RelabelFst<A> >; 191 public: 192 using FstImpl<A>::SetType; 193 using FstImpl<A>::SetProperties; 194 using FstImpl<A>::WriteHeader; 195 using FstImpl<A>::SetInputSymbols; 196 using FstImpl<A>::SetOutputSymbols; 197 198 using CacheImpl<A>::PushArc; 199 using CacheImpl<A>::HasArcs; 200 using CacheImpl<A>::HasFinal; 201 using CacheImpl<A>::HasStart; 202 using CacheImpl<A>::SetArcs; 203 using CacheImpl<A>::SetFinal; 204 using CacheImpl<A>::SetStart; 205 206 typedef A Arc; 207 typedef typename A::Label Label; 208 typedef typename A::Weight Weight; 209 typedef typename A::StateId StateId; 210 typedef CacheState<A> State; 211 212 RelabelFstImpl(const Fst<A>& fst, 213 const vector<pair<Label, Label> >& ipairs, 214 const vector<pair<Label, Label> >& opairs, 215 const RelabelFstOptions &opts) 216 : CacheImpl<A>(opts), fst_(fst.Copy()), 217 relabel_input_(false), relabel_output_(false) { 218 uint64 props = fst.Properties(kCopyProperties, false); 219 SetProperties(RelabelProperties(props)); 220 SetType("relabel"); 221 222 // create input label map 223 if (ipairs.size() > 0) { 224 for (size_t i = 0; i < ipairs.size(); ++i) { 225 input_map_[ipairs[i].first] = ipairs[i].second; 226 } 227 relabel_input_ = true; 228 } 229 230 // create output label map 231 if (opairs.size() > 0) { 232 for (size_t i = 0; i < opairs.size(); ++i) { 233 output_map_[opairs[i].first] = opairs[i].second; 234 } 235 relabel_output_ = true; 236 } 237 } 238 239 RelabelFstImpl(const Fst<A>& fst, 240 const SymbolTable* old_isymbols, 241 const SymbolTable* new_isymbols, 242 const SymbolTable* old_osymbols, 243 const SymbolTable* new_osymbols, 244 const RelabelFstOptions &opts) 245 : CacheImpl<A>(opts), fst_(fst.Copy()), 246 relabel_input_(false), relabel_output_(false) { 247 SetType("relabel"); 248 249 uint64 props = fst.Properties(kCopyProperties, false); 250 SetProperties(RelabelProperties(props)); 251 SetInputSymbols(old_isymbols); 252 SetOutputSymbols(old_osymbols); 253 254 if (old_isymbols && new_isymbols && 255 old_isymbols->LabeledCheckSum() != new_isymbols->LabeledCheckSum()) { 256 for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done(); 257 syms_iter.Next()) { 258 input_map_[syms_iter.Value()] = new_isymbols->Find(syms_iter.Symbol()); 259 } 260 SetInputSymbols(new_isymbols); 261 relabel_input_ = true; 262 } 263 264 if (old_osymbols && new_osymbols && 265 old_osymbols->LabeledCheckSum() != new_osymbols->LabeledCheckSum()) { 266 for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done(); 267 syms_iter.Next()) { 268 output_map_[syms_iter.Value()] = 269 new_osymbols->Find(syms_iter.Symbol()); 270 } 271 SetOutputSymbols(new_osymbols); 272 relabel_output_ = true; 273 } 274 } 275 276 RelabelFstImpl(const RelabelFstImpl<A>& impl) 277 : CacheImpl<A>(impl), 278 fst_(impl.fst_->Copy(true)), 279 input_map_(impl.input_map_), 280 output_map_(impl.output_map_), 281 relabel_input_(impl.relabel_input_), 282 relabel_output_(impl.relabel_output_) { 283 SetType("relabel"); 284 SetProperties(impl.Properties(), kCopyProperties); 285 SetInputSymbols(impl.InputSymbols()); 286 SetOutputSymbols(impl.OutputSymbols()); 287 } 288 289 ~RelabelFstImpl() { delete fst_; } 290 291 StateId Start() { 292 if (!HasStart()) { 293 StateId s = fst_->Start(); 294 SetStart(s); 295 } 296 return CacheImpl<A>::Start(); 297 } 298 299 Weight Final(StateId s) { 300 if (!HasFinal(s)) { 301 SetFinal(s, fst_->Final(s)); 302 } 303 return CacheImpl<A>::Final(s); 304 } 305 306 size_t NumArcs(StateId s) { 307 if (!HasArcs(s)) { 308 Expand(s); 309 } 310 return CacheImpl<A>::NumArcs(s); 311 } 312 313 size_t NumInputEpsilons(StateId s) { 314 if (!HasArcs(s)) { 315 Expand(s); 316 } 317 return CacheImpl<A>::NumInputEpsilons(s); 318 } 319 320 size_t NumOutputEpsilons(StateId s) { 321 if (!HasArcs(s)) { 322 Expand(s); 323 } 324 return CacheImpl<A>::NumOutputEpsilons(s); 325 } 326 327 uint64 Properties() const { return Properties(kFstProperties); } 328 329 // Set error if found; return FST impl properties. 330 uint64 Properties(uint64 mask) const { 331 if ((mask & kError) && fst_->Properties(kError, false)) 332 SetProperties(kError, kError); 333 return FstImpl<Arc>::Properties(mask); 334 } 335 336 void InitArcIterator(StateId s, ArcIteratorData<A>* data) { 337 if (!HasArcs(s)) { 338 Expand(s); 339 } 340 CacheImpl<A>::InitArcIterator(s, data); 341 } 342 343 void Expand(StateId s) { 344 for (ArcIterator<Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) { 345 A arc = aiter.Value(); 346 347 // relabel input 348 if (relabel_input_) { 349 typename unordered_map<Label, Label>::iterator it = 350 input_map_.find(arc.ilabel); 351 if (it != input_map_.end()) { arc.ilabel = it->second; } 352 } 353 354 // relabel output 355 if (relabel_output_) { 356 typename unordered_map<Label, Label>::iterator it = 357 output_map_.find(arc.olabel); 358 if (it != output_map_.end()) { arc.olabel = it->second; } 359 } 360 361 PushArc(s, arc); 362 } 363 SetArcs(s); 364 } 365 366 367 private: 368 const Fst<A> *fst_; 369 370 unordered_map<Label, Label> input_map_; 371 unordered_map<Label, Label> output_map_; 372 bool relabel_input_; 373 bool relabel_output_; 374 375 void operator=(const RelabelFstImpl<A> &); // disallow 376 }; 377 378 379 // 380 // \class RelabelFst 381 // \brief Delayed implementation of arc relabeling 382 // 383 // This class attaches interface to implementation and handles 384 // reference counting, delegating most methods to ImplToFst. 385 template <class A> 386 class RelabelFst : public ImplToFst< RelabelFstImpl<A> > { 387 public: 388 friend class ArcIterator< RelabelFst<A> >; 389 friend class StateIterator< RelabelFst<A> >; 390 391 typedef A Arc; 392 typedef typename A::Label Label; 393 typedef typename A::Weight Weight; 394 typedef typename A::StateId StateId; 395 typedef CacheState<A> State; 396 typedef RelabelFstImpl<A> Impl; 397 398 RelabelFst(const Fst<A>& fst, 399 const vector<pair<Label, Label> >& ipairs, 400 const vector<pair<Label, Label> >& opairs) 401 : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, RelabelFstOptions())) {} 402 403 RelabelFst(const Fst<A>& fst, 404 const vector<pair<Label, Label> >& ipairs, 405 const vector<pair<Label, Label> >& opairs, 406 const RelabelFstOptions &opts) 407 : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, opts)) {} 408 409 RelabelFst(const Fst<A>& fst, 410 const SymbolTable* new_isymbols, 411 const SymbolTable* new_osymbols) 412 : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols, 413 fst.OutputSymbols(), new_osymbols, 414 RelabelFstOptions())) {} 415 416 RelabelFst(const Fst<A>& fst, 417 const SymbolTable* new_isymbols, 418 const SymbolTable* new_osymbols, 419 const RelabelFstOptions &opts) 420 : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols, 421 fst.OutputSymbols(), new_osymbols, opts)) {} 422 423 RelabelFst(const Fst<A>& fst, 424 const SymbolTable* old_isymbols, 425 const SymbolTable* new_isymbols, 426 const SymbolTable* old_osymbols, 427 const SymbolTable* new_osymbols) 428 : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols, 429 new_osymbols, RelabelFstOptions())) {} 430 431 RelabelFst(const Fst<A>& fst, 432 const SymbolTable* old_isymbols, 433 const SymbolTable* new_isymbols, 434 const SymbolTable* old_osymbols, 435 const SymbolTable* new_osymbols, 436 const RelabelFstOptions &opts) 437 : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols, 438 new_osymbols, opts)) {} 439 440 // See Fst<>::Copy() for doc. 441 RelabelFst(const RelabelFst<A> &fst, bool safe = false) 442 : ImplToFst<Impl>(fst, safe) {} 443 444 // Get a copy of this RelabelFst. See Fst<>::Copy() for further doc. 445 virtual RelabelFst<A> *Copy(bool safe = false) const { 446 return new RelabelFst<A>(*this, safe); 447 } 448 449 virtual void InitStateIterator(StateIteratorData<A> *data) const; 450 451 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { 452 return GetImpl()->InitArcIterator(s, data); 453 } 454 455 private: 456 // Makes visible to friends. 457 Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } 458 459 void operator=(const RelabelFst<A> &fst); // disallow 460 }; 461 462 // Specialization for RelabelFst. 463 template<class A> 464 class StateIterator< RelabelFst<A> > : public StateIteratorBase<A> { 465 public: 466 typedef typename A::StateId StateId; 467 468 explicit StateIterator(const RelabelFst<A> &fst) 469 : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {} 470 471 bool Done() const { return siter_.Done(); } 472 473 StateId Value() const { return s_; } 474 475 void Next() { 476 if (!siter_.Done()) { 477 ++s_; 478 siter_.Next(); 479 } 480 } 481 482 void Reset() { 483 s_ = 0; 484 siter_.Reset(); 485 } 486 487 private: 488 bool Done_() const { return Done(); } 489 StateId Value_() const { return Value(); } 490 void Next_() { Next(); } 491 void Reset_() { Reset(); } 492 493 const RelabelFstImpl<A> *impl_; 494 StateIterator< Fst<A> > siter_; 495 StateId s_; 496 497 DISALLOW_COPY_AND_ASSIGN(StateIterator); 498 }; 499 500 501 // Specialization for RelabelFst. 502 template <class A> 503 class ArcIterator< RelabelFst<A> > 504 : public CacheArcIterator< RelabelFst<A> > { 505 public: 506 typedef typename A::StateId StateId; 507 508 ArcIterator(const RelabelFst<A> &fst, StateId s) 509 : CacheArcIterator< RelabelFst<A> >(fst.GetImpl(), s) { 510 if (!fst.GetImpl()->HasArcs(s)) 511 fst.GetImpl()->Expand(s); 512 } 513 514 private: 515 DISALLOW_COPY_AND_ASSIGN(ArcIterator); 516 }; 517 518 template <class A> inline 519 void RelabelFst<A>::InitStateIterator(StateIteratorData<A> *data) const { 520 data->base = new StateIterator< RelabelFst<A> >(*this); 521 } 522 523 // Useful alias when using StdArc. 524 typedef RelabelFst<StdArc> StdRelabelFst; 525 526 } // namespace fst 527 528 #endif // FST_LIB_RELABEL_H__ 529