1 // relabel.h 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // 16 // \file 17 // Functions and classes to relabel an Fst (either on input or output) 18 // 19 #ifndef FST_LIB_RELABEL_H__ 20 #define FST_LIB_RELABEL_H__ 21 22 #include <unordered_map> 23 24 #include "fst/lib/cache.h" 25 #include "fst/lib/test-properties.h" 26 27 28 namespace fst { 29 30 // 31 // Relabels either the input labels or output labels. The old to 32 // new labels are specified using a vector of pair<Label,Label>. 33 // Any label associations not specified are assumed to be identity 34 // mapping. 35 // 36 // \param fst input fst, must be mutable 37 // \param relabel_pairs vector of pairs indicating old to new mapping 38 // \param relabel_flags whether to relabel input or output 39 // 40 template <class A> 41 void Relabel( 42 MutableFst<A> *fst, 43 const vector<pair<typename A::Label, typename A::Label> >& ipairs, 44 const vector<pair<typename A::Label, typename A::Label> >& opairs) { 45 typedef typename A::StateId StateId; 46 typedef typename A::Label Label; 47 48 uint64 props = fst->Properties(kFstProperties, false); 49 50 // construct label to label hash. Could 51 std::unordered_map<Label, Label> input_map; 52 for (size_t i = 0; i < ipairs.size(); ++i) { 53 input_map[ipairs[i].first] = ipairs[i].second; 54 } 55 56 std::unordered_map<Label, Label> output_map; 57 for (size_t i = 0; i < opairs.size(); ++i) { 58 output_map[opairs[i].first] = opairs[i].second; 59 } 60 61 for (StateIterator<MutableFst<A> > siter(*fst); 62 !siter.Done(); siter.Next()) { 63 StateId s = siter.Value(); 64 for (MutableArcIterator<MutableFst<A> > aiter(fst, s); 65 !aiter.Done(); aiter.Next()) { 66 A arc = aiter.Value(); 67 68 // relabel input 69 // only relabel if relabel pair defined 70 typename std::unordered_map<Label, Label>::iterator it = 71 input_map.find(arc.ilabel); 72 if (it != input_map.end()) {arc.ilabel = it->second; } 73 74 // relabel output 75 it = output_map.find(arc.olabel); 76 if (it != output_map.end()) { arc.olabel = it->second; } 77 78 aiter.SetValue(arc); 79 } 80 } 81 82 fst->SetProperties(RelabelProperties(props), kFstProperties); 83 } 84 85 86 87 // 88 // Relabels either the input labels or output labels. The old to 89 // new labels mappings are specified using an input Symbol set. 90 // Any label associations not specified are assumed to be identity 91 // mapping. 92 // 93 // \param fst input fst, must be mutable 94 // \param new_symbols symbol set indicating new mapping 95 // \param relabel_flags whether to relabel input or output 96 // 97 template<class A> 98 void Relabel(MutableFst<A> *fst, 99 const SymbolTable* new_isymbols, 100 const SymbolTable* new_osymbols) { 101 typedef typename A::StateId StateId; 102 typedef typename A::Label Label; 103 104 const SymbolTable* old_isymbols = fst->InputSymbols(); 105 const SymbolTable* old_osymbols = fst->OutputSymbols(); 106 107 vector<pair<Label, Label> > ipairs; 108 if (old_isymbols && new_isymbols) { 109 for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done(); 110 syms_iter.Next()) { 111 ipairs.push_back(make_pair(syms_iter.Value(), 112 new_isymbols->Find(syms_iter.Symbol()))); 113 } 114 fst->SetInputSymbols(new_isymbols); 115 } 116 117 vector<pair<Label, Label> > opairs; 118 if (old_osymbols && new_osymbols) { 119 for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done(); 120 syms_iter.Next()) { 121 opairs.push_back(make_pair(syms_iter.Value(), 122 new_osymbols->Find(syms_iter.Symbol()))); 123 } 124 fst->SetOutputSymbols(new_osymbols); 125 } 126 127 // call relabel using vector of relabel pairs. 128 Relabel(fst, ipairs, opairs); 129 } 130 131 132 typedef CacheOptions RelabelFstOptions; 133 134 template <class A> class RelabelFst; 135 136 // 137 // \class RelabelFstImpl 138 // \brief Implementation for delayed relabeling 139 // 140 // Relabels an FST from one symbol set to another. Relabeling 141 // can either be on input or output space. RelabelFst implements 142 // a delayed version of the relabel. Arcs are relabeled on the fly 143 // and not cached. I.e each request is recomputed. 144 // 145 template<class A> 146 class RelabelFstImpl : public CacheImpl<A> { 147 friend class StateIterator< RelabelFst<A> >; 148 public: 149 using FstImpl<A>::SetType; 150 using FstImpl<A>::SetProperties; 151 using FstImpl<A>::Properties; 152 using FstImpl<A>::SetInputSymbols; 153 using FstImpl<A>::SetOutputSymbols; 154 155 using CacheImpl<A>::HasStart; 156 using CacheImpl<A>::HasArcs; 157 158 typedef typename A::Label Label; 159 typedef typename A::Weight Weight; 160 typedef typename A::StateId StateId; 161 typedef CacheState<A> State; 162 163 RelabelFstImpl(const Fst<A>& fst, 164 const vector<pair<Label, Label> >& ipairs, 165 const vector<pair<Label, Label> >& opairs, 166 const RelabelFstOptions &opts) 167 : CacheImpl<A>(opts), fst_(fst.Copy()), 168 relabel_input_(false), relabel_output_(false) { 169 uint64 props = fst.Properties(kCopyProperties, false); 170 SetProperties(RelabelProperties(props)); 171 SetType("relabel"); 172 173 // create input label map 174 if (ipairs.size() > 0) { 175 for (size_t i = 0; i < ipairs.size(); ++i) { 176 input_map_[ipairs[i].first] = ipairs[i].second; 177 } 178 relabel_input_ = true; 179 } 180 181 // create output label map 182 if (opairs.size() > 0) { 183 for (size_t i = 0; i < opairs.size(); ++i) { 184 output_map_[opairs[i].first] = opairs[i].second; 185 } 186 relabel_output_ = true; 187 } 188 } 189 190 RelabelFstImpl(const Fst<A>& fst, 191 const SymbolTable* new_isymbols, 192 const SymbolTable* new_osymbols, 193 const RelabelFstOptions &opts) 194 : CacheImpl<A>(opts), fst_(fst.Copy()), 195 relabel_input_(false), relabel_output_(false) { 196 SetType("relabel"); 197 198 uint64 props = fst.Properties(kCopyProperties, false); 199 SetProperties(RelabelProperties(props)); 200 SetInputSymbols(fst.InputSymbols()); 201 SetOutputSymbols(fst.OutputSymbols()); 202 203 const SymbolTable* old_isymbols = fst.InputSymbols(); 204 const SymbolTable* old_osymbols = fst.OutputSymbols(); 205 206 if (old_isymbols && new_isymbols && 207 old_isymbols->CheckSum() != new_isymbols->CheckSum()) { 208 for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done(); 209 syms_iter.Next()) { 210 input_map_[syms_iter.Value()] = new_isymbols->Find(syms_iter.Symbol()); 211 } 212 SetInputSymbols(new_isymbols); 213 relabel_input_ = true; 214 } 215 216 if (old_osymbols && new_osymbols && 217 old_osymbols->CheckSum() != new_osymbols->CheckSum()) { 218 for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done(); 219 syms_iter.Next()) { 220 output_map_[syms_iter.Value()] = 221 new_osymbols->Find(syms_iter.Symbol()); 222 } 223 SetOutputSymbols(new_osymbols); 224 relabel_output_ = true; 225 } 226 } 227 228 ~RelabelFstImpl() { delete fst_; } 229 230 StateId Start() { 231 if (!HasStart()) { 232 StateId s = fst_->Start(); 233 SetStart(s); 234 } 235 return CacheImpl<A>::Start(); 236 } 237 238 Weight Final(StateId s) { 239 if (!HasFinal(s)) { 240 SetFinal(s, fst_->Final(s)); 241 } 242 return CacheImpl<A>::Final(s); 243 } 244 245 size_t NumArcs(StateId s) { 246 if (!HasArcs(s)) { 247 Expand(s); 248 } 249 return CacheImpl<A>::NumArcs(s); 250 } 251 252 size_t NumInputEpsilons(StateId s) { 253 if (!HasArcs(s)) { 254 Expand(s); 255 } 256 return CacheImpl<A>::NumInputEpsilons(s); 257 } 258 259 size_t NumOutputEpsilons(StateId s) { 260 if (!HasArcs(s)) { 261 Expand(s); 262 } 263 return CacheImpl<A>::NumOutputEpsilons(s); 264 } 265 266 void InitArcIterator(StateId s, ArcIteratorData<A>* data) { 267 if (!HasArcs(s)) { 268 Expand(s); 269 } 270 CacheImpl<A>::InitArcIterator(s, data); 271 } 272 273 void Expand(StateId s) { 274 for (ArcIterator<Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) { 275 A arc = aiter.Value(); 276 277 // relabel input 278 if (relabel_input_) { 279 typename std::unordered_map<Label, Label>::iterator it = 280 input_map_.find(arc.ilabel); 281 if (it != input_map_.end()) { arc.ilabel = it->second; } 282 } 283 284 // relabel output 285 if (relabel_output_) { 286 typename std::unordered_map<Label, Label>::iterator it = 287 output_map_.find(arc.olabel); 288 if (it != output_map_.end()) { arc.olabel = it->second; } 289 } 290 291 AddArc(s, arc); 292 } 293 SetArcs(s); 294 } 295 296 297 private: 298 const Fst<A> *fst_; 299 300 std::unordered_map<Label, Label> input_map_; 301 std::unordered_map<Label, Label> output_map_; 302 bool relabel_input_; 303 bool relabel_output_; 304 305 DISALLOW_EVIL_CONSTRUCTORS(RelabelFstImpl); 306 }; 307 308 309 // 310 // \class RelabelFst 311 // \brief Delayed implementation of arc relabeling 312 // 313 // This class attaches interface to implementation and handles 314 // reference counting. 315 template <class A> 316 class RelabelFst : public Fst<A> { 317 public: 318 friend class ArcIterator< RelabelFst<A> >; 319 friend class StateIterator< RelabelFst<A> >; 320 friend class CacheArcIterator< RelabelFst<A> >; 321 322 typedef A Arc; 323 typedef typename A::Label Label; 324 typedef typename A::Weight Weight; 325 typedef typename A::StateId StateId; 326 typedef CacheState<A> State; 327 328 RelabelFst(const Fst<A>& fst, 329 const vector<pair<Label, Label> >& ipairs, 330 const vector<pair<Label, Label> >& opairs) : 331 impl_(new RelabelFstImpl<A>(fst, ipairs, opairs, RelabelFstOptions())) {} 332 333 RelabelFst(const Fst<A>& fst, 334 const vector<pair<Label, Label> >& ipairs, 335 const vector<pair<Label, Label> >& opairs, 336 const RelabelFstOptions &opts) 337 : impl_(new RelabelFstImpl<A>(fst, ipairs, opairs, opts)) {} 338 339 RelabelFst(const Fst<A>& fst, 340 const SymbolTable* new_isymbols, 341 const SymbolTable* new_osymbols) : 342 impl_(new RelabelFstImpl<A>(fst, new_isymbols, new_osymbols, 343 RelabelFstOptions())) {} 344 345 RelabelFst(const Fst<A>& fst, 346 const SymbolTable* new_isymbols, 347 const SymbolTable* new_osymbols, 348 const RelabelFstOptions &opts) 349 : impl_(new RelabelFstImpl<A>(fst, new_isymbols, new_osymbols, opts)) {} 350 351 RelabelFst(const RelabelFst<A> &fst) : impl_(fst.impl_) { 352 impl_->IncrRefCount(); 353 } 354 355 virtual ~RelabelFst() { if (!impl_->DecrRefCount()) delete impl_; } 356 357 virtual StateId Start() const { return impl_->Start(); } 358 359 virtual Weight Final(StateId s) const { return impl_->Final(s); } 360 361 virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); } 362 363 virtual size_t NumInputEpsilons(StateId s) const { 364 return impl_->NumInputEpsilons(s); 365 } 366 367 virtual size_t NumOutputEpsilons(StateId s) const { 368 return impl_->NumOutputEpsilons(s); 369 } 370 371 virtual uint64 Properties(uint64 mask, bool test) const { 372 if (test) { 373 uint64 known, test = TestProperties(*this, mask, &known); 374 impl_->SetProperties(test, known); 375 return test & mask; 376 } else { 377 return impl_->Properties(mask); 378 } 379 } 380 381 virtual const string& Type() const { return impl_->Type(); } 382 383 virtual RelabelFst<A> *Copy() const { 384 return new RelabelFst<A>(*this); 385 } 386 387 virtual const SymbolTable* InputSymbols() const { 388 return impl_->InputSymbols(); 389 } 390 391 virtual const SymbolTable* OutputSymbols() const { 392 return impl_->OutputSymbols(); 393 } 394 395 virtual void InitStateIterator(StateIteratorData<A> *data) const; 396 397 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { 398 return impl_->InitArcIterator(s, data); 399 } 400 401 private: 402 RelabelFstImpl<A> *impl_; 403 404 void operator=(const RelabelFst<A> &fst); // disallow 405 }; 406 407 // Specialization for RelabelFst. 408 template<class A> 409 class StateIterator< RelabelFst<A> > : public StateIteratorBase<A> { 410 public: 411 typedef typename A::StateId StateId; 412 413 explicit StateIterator(const RelabelFst<A> &fst) 414 : impl_(fst.impl_), siter_(*impl_->fst_), s_(0) {} 415 416 bool Done() const { return siter_.Done(); } 417 418 StateId Value() const { return s_; } 419 420 void Next() { 421 if (!siter_.Done()) { 422 ++s_; 423 siter_.Next(); 424 } 425 } 426 427 void Reset() { 428 s_ = 0; 429 siter_.Reset(); 430 } 431 432 private: 433 const RelabelFstImpl<A> *impl_; 434 StateIterator< Fst<A> > siter_; 435 StateId s_; 436 437 DISALLOW_EVIL_CONSTRUCTORS(StateIterator); 438 }; 439 440 441 // Specialization for RelabelFst. 442 template <class A> 443 class ArcIterator< RelabelFst<A> > 444 : public CacheArcIterator< RelabelFst<A> > { 445 public: 446 typedef typename A::StateId StateId; 447 448 ArcIterator(const RelabelFst<A> &fst, StateId s) 449 : CacheArcIterator< RelabelFst<A> >(fst, s) { 450 if (!fst.impl_->HasArcs(s)) 451 fst.impl_->Expand(s); 452 } 453 454 private: 455 DISALLOW_EVIL_CONSTRUCTORS(ArcIterator); 456 }; 457 458 template <class A> inline 459 void RelabelFst<A>::InitStateIterator(StateIteratorData<A> *data) const { 460 data->base = new StateIterator< RelabelFst<A> >(*this); 461 } 462 463 // Useful alias when using StdArc. 464 typedef RelabelFst<StdArc> StdRelabelFst; 465 466 } // namespace fst 467 468 #endif // FST_LIB_RELABEL_H__ 469