Home | History | Annotate | Download | only in lib
      1 // relabel.h
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 //
     16 // \file
     17 // Functions and classes to relabel an Fst (either on input or output)
     18 //
     19 #ifndef FST_LIB_RELABEL_H__
     20 #define FST_LIB_RELABEL_H__
     21 
     22 #include <unordered_map>
     23 
     24 #include "fst/lib/cache.h"
     25 #include "fst/lib/test-properties.h"
     26 
     27 
     28 namespace fst {
     29 
     30 //
     31 // Relabels either the input labels or output labels. The old to
     32 // new labels are specified using a vector of pair<Label,Label>.
     33 // Any label associations not specified are assumed to be identity
     34 // mapping.
     35 //
     36 // \param fst input fst, must be mutable
     37 // \param relabel_pairs vector of pairs indicating old to new mapping
     38 // \param relabel_flags whether to relabel input or output
     39 //
     40 template <class A>
     41 void Relabel(
     42     MutableFst<A> *fst,
     43     const vector<pair<typename A::Label, typename A::Label> >& ipairs,
     44     const vector<pair<typename A::Label, typename A::Label> >& opairs) {
     45   typedef typename A::StateId StateId;
     46   typedef typename A::Label   Label;
     47 
     48   uint64 props = fst->Properties(kFstProperties, false);
     49 
     50   // construct label to label hash. Could
     51   std::unordered_map<Label, Label> input_map;
     52   for (size_t i = 0; i < ipairs.size(); ++i) {
     53     input_map[ipairs[i].first] = ipairs[i].second;
     54   }
     55 
     56   std::unordered_map<Label, Label> output_map;
     57   for (size_t i = 0; i < opairs.size(); ++i) {
     58     output_map[opairs[i].first] = opairs[i].second;
     59   }
     60 
     61   for (StateIterator<MutableFst<A> > siter(*fst);
     62        !siter.Done(); siter.Next()) {
     63     StateId s = siter.Value();
     64     for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
     65          !aiter.Done(); aiter.Next()) {
     66       A arc = aiter.Value();
     67 
     68       // relabel input
     69       // only relabel if relabel pair defined
     70       typename std::unordered_map<Label, Label>::iterator it =
     71         input_map.find(arc.ilabel);
     72       if (it != input_map.end()) {arc.ilabel = it->second; }
     73 
     74       // relabel output
     75       it = output_map.find(arc.olabel);
     76       if (it != output_map.end()) { arc.olabel = it->second; }
     77 
     78       aiter.SetValue(arc);
     79     }
     80   }
     81 
     82   fst->SetProperties(RelabelProperties(props), kFstProperties);
     83 }
     84 
     85 
     86 
     87 //
     88 // Relabels either the input labels or output labels. The old to
     89 // new labels mappings are specified using an input Symbol set.
     90 // Any label associations not specified are assumed to be identity
     91 // mapping.
     92 //
     93 // \param fst input fst, must be mutable
     94 // \param new_symbols symbol set indicating new mapping
     95 // \param relabel_flags whether to relabel input or output
     96 //
     97 template<class A>
     98 void Relabel(MutableFst<A> *fst,
     99              const SymbolTable* new_isymbols,
    100              const SymbolTable* new_osymbols) {
    101   typedef typename A::StateId StateId;
    102   typedef typename A::Label   Label;
    103 
    104   const SymbolTable* old_isymbols = fst->InputSymbols();
    105   const SymbolTable* old_osymbols = fst->OutputSymbols();
    106 
    107   vector<pair<Label, Label> > ipairs;
    108   if (old_isymbols && new_isymbols) {
    109     for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
    110          syms_iter.Next()) {
    111       ipairs.push_back(make_pair(syms_iter.Value(),
    112                                  new_isymbols->Find(syms_iter.Symbol())));
    113     }
    114     fst->SetInputSymbols(new_isymbols);
    115   }
    116 
    117   vector<pair<Label, Label> > opairs;
    118   if (old_osymbols && new_osymbols) {
    119     for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
    120          syms_iter.Next()) {
    121       opairs.push_back(make_pair(syms_iter.Value(),
    122                                  new_osymbols->Find(syms_iter.Symbol())));
    123     }
    124     fst->SetOutputSymbols(new_osymbols);
    125   }
    126 
    127   // call relabel using vector of relabel pairs.
    128   Relabel(fst, ipairs, opairs);
    129 }
    130 
    131 
    132 typedef CacheOptions RelabelFstOptions;
    133 
    134 template <class A> class RelabelFst;
    135 
    136 //
    137 // \class RelabelFstImpl
    138 // \brief Implementation for delayed relabeling
    139 //
    140 // Relabels an FST from one symbol set to another. Relabeling
    141 // can either be on input or output space. RelabelFst implements
    142 // a delayed version of the relabel. Arcs are relabeled on the fly
    143 // and not cached. I.e each request is recomputed.
    144 //
    145 template<class A>
    146 class RelabelFstImpl : public CacheImpl<A> {
    147   friend class StateIterator< RelabelFst<A> >;
    148  public:
    149   using FstImpl<A>::SetType;
    150   using FstImpl<A>::SetProperties;
    151   using FstImpl<A>::Properties;
    152   using FstImpl<A>::SetInputSymbols;
    153   using FstImpl<A>::SetOutputSymbols;
    154 
    155   using CacheImpl<A>::HasStart;
    156   using CacheImpl<A>::HasArcs;
    157 
    158   typedef typename A::Label   Label;
    159   typedef typename A::Weight  Weight;
    160   typedef typename A::StateId StateId;
    161   typedef CacheState<A> State;
    162 
    163   RelabelFstImpl(const Fst<A>& fst,
    164                  const vector<pair<Label, Label> >& ipairs,
    165                  const vector<pair<Label, Label> >& opairs,
    166                  const RelabelFstOptions &opts)
    167       : CacheImpl<A>(opts), fst_(fst.Copy()),
    168         relabel_input_(false), relabel_output_(false) {
    169     uint64 props = fst.Properties(kCopyProperties, false);
    170     SetProperties(RelabelProperties(props));
    171     SetType("relabel");
    172 
    173     // create input label map
    174     if (ipairs.size() > 0) {
    175       for (size_t i = 0; i < ipairs.size(); ++i) {
    176         input_map_[ipairs[i].first] = ipairs[i].second;
    177       }
    178       relabel_input_ = true;
    179     }
    180 
    181     // create output label map
    182     if (opairs.size() > 0) {
    183       for (size_t i = 0; i < opairs.size(); ++i) {
    184         output_map_[opairs[i].first] = opairs[i].second;
    185       }
    186       relabel_output_ = true;
    187     }
    188   }
    189 
    190   RelabelFstImpl(const Fst<A>& fst,
    191                  const SymbolTable* new_isymbols,
    192                  const SymbolTable* new_osymbols,
    193                  const RelabelFstOptions &opts)
    194       : CacheImpl<A>(opts), fst_(fst.Copy()),
    195         relabel_input_(false), relabel_output_(false) {
    196     SetType("relabel");
    197 
    198     uint64 props = fst.Properties(kCopyProperties, false);
    199     SetProperties(RelabelProperties(props));
    200     SetInputSymbols(fst.InputSymbols());
    201     SetOutputSymbols(fst.OutputSymbols());
    202 
    203     const SymbolTable* old_isymbols = fst.InputSymbols();
    204     const SymbolTable* old_osymbols = fst.OutputSymbols();
    205 
    206     if (old_isymbols && new_isymbols &&
    207         old_isymbols->CheckSum() != new_isymbols->CheckSum()) {
    208       for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
    209            syms_iter.Next()) {
    210         input_map_[syms_iter.Value()] = new_isymbols->Find(syms_iter.Symbol());
    211       }
    212       SetInputSymbols(new_isymbols);
    213       relabel_input_ = true;
    214     }
    215 
    216     if (old_osymbols && new_osymbols &&
    217         old_osymbols->CheckSum() != new_osymbols->CheckSum()) {
    218       for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
    219            syms_iter.Next()) {
    220         output_map_[syms_iter.Value()] =
    221           new_osymbols->Find(syms_iter.Symbol());
    222       }
    223       SetOutputSymbols(new_osymbols);
    224       relabel_output_ = true;
    225     }
    226   }
    227 
    228   ~RelabelFstImpl() { delete fst_; }
    229 
    230   StateId Start() {
    231     if (!HasStart()) {
    232       StateId s = fst_->Start();
    233       SetStart(s);
    234     }
    235     return CacheImpl<A>::Start();
    236   }
    237 
    238   Weight Final(StateId s) {
    239     if (!HasFinal(s)) {
    240       SetFinal(s, fst_->Final(s));
    241     }
    242     return CacheImpl<A>::Final(s);
    243   }
    244 
    245   size_t NumArcs(StateId s) {
    246     if (!HasArcs(s)) {
    247       Expand(s);
    248     }
    249     return CacheImpl<A>::NumArcs(s);
    250   }
    251 
    252   size_t NumInputEpsilons(StateId s) {
    253     if (!HasArcs(s)) {
    254       Expand(s);
    255     }
    256     return CacheImpl<A>::NumInputEpsilons(s);
    257   }
    258 
    259   size_t NumOutputEpsilons(StateId s) {
    260     if (!HasArcs(s)) {
    261       Expand(s);
    262     }
    263     return CacheImpl<A>::NumOutputEpsilons(s);
    264   }
    265 
    266   void InitArcIterator(StateId s, ArcIteratorData<A>* data) {
    267     if (!HasArcs(s)) {
    268       Expand(s);
    269     }
    270     CacheImpl<A>::InitArcIterator(s, data);
    271   }
    272 
    273   void Expand(StateId s) {
    274     for (ArcIterator<Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
    275       A arc = aiter.Value();
    276 
    277       // relabel input
    278       if (relabel_input_) {
    279         typename std::unordered_map<Label, Label>::iterator it =
    280           input_map_.find(arc.ilabel);
    281         if (it != input_map_.end()) { arc.ilabel = it->second; }
    282       }
    283 
    284       // relabel output
    285       if (relabel_output_) {
    286         typename std::unordered_map<Label, Label>::iterator it =
    287           output_map_.find(arc.olabel);
    288         if (it != output_map_.end()) { arc.olabel = it->second; }
    289       }
    290 
    291       AddArc(s, arc);
    292     }
    293     SetArcs(s);
    294   }
    295 
    296 
    297  private:
    298   const Fst<A> *fst_;
    299 
    300   std::unordered_map<Label, Label> input_map_;
    301   std::unordered_map<Label, Label> output_map_;
    302   bool relabel_input_;
    303   bool relabel_output_;
    304 
    305   DISALLOW_EVIL_CONSTRUCTORS(RelabelFstImpl);
    306 };
    307 
    308 
    309 //
    310 // \class RelabelFst
    311 // \brief Delayed implementation of arc relabeling
    312 //
    313 // This class attaches interface to implementation and handles
    314 // reference counting.
    315 template <class A>
    316 class RelabelFst : public Fst<A> {
    317  public:
    318   friend class ArcIterator< RelabelFst<A> >;
    319   friend class StateIterator< RelabelFst<A> >;
    320   friend class CacheArcIterator< RelabelFst<A> >;
    321 
    322   typedef A Arc;
    323   typedef typename A::Label   Label;
    324   typedef typename A::Weight  Weight;
    325   typedef typename A::StateId StateId;
    326   typedef CacheState<A> State;
    327 
    328   RelabelFst(const Fst<A>& fst,
    329              const vector<pair<Label, Label> >& ipairs,
    330              const vector<pair<Label, Label> >& opairs) :
    331       impl_(new RelabelFstImpl<A>(fst, ipairs, opairs, RelabelFstOptions())) {}
    332 
    333   RelabelFst(const Fst<A>& fst,
    334              const vector<pair<Label, Label> >& ipairs,
    335              const vector<pair<Label, Label> >& opairs,
    336              const RelabelFstOptions &opts)
    337       : impl_(new RelabelFstImpl<A>(fst, ipairs, opairs, opts)) {}
    338 
    339   RelabelFst(const Fst<A>& fst,
    340              const SymbolTable* new_isymbols,
    341              const SymbolTable* new_osymbols) :
    342       impl_(new RelabelFstImpl<A>(fst, new_isymbols, new_osymbols,
    343                                   RelabelFstOptions())) {}
    344 
    345   RelabelFst(const Fst<A>& fst,
    346              const SymbolTable* new_isymbols,
    347              const SymbolTable* new_osymbols,
    348              const RelabelFstOptions &opts)
    349     : impl_(new RelabelFstImpl<A>(fst, new_isymbols, new_osymbols, opts)) {}
    350 
    351   RelabelFst(const RelabelFst<A> &fst) : impl_(fst.impl_) {
    352     impl_->IncrRefCount();
    353   }
    354 
    355   virtual ~RelabelFst() { if (!impl_->DecrRefCount()) delete impl_;  }
    356 
    357   virtual StateId Start() const { return impl_->Start(); }
    358 
    359   virtual Weight Final(StateId s) const { return impl_->Final(s); }
    360 
    361   virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
    362 
    363   virtual size_t NumInputEpsilons(StateId s) const {
    364     return impl_->NumInputEpsilons(s);
    365   }
    366 
    367   virtual size_t NumOutputEpsilons(StateId s) const {
    368     return impl_->NumOutputEpsilons(s);
    369   }
    370 
    371   virtual uint64 Properties(uint64 mask, bool test) const {
    372     if (test) {
    373       uint64 known, test = TestProperties(*this, mask, &known);
    374       impl_->SetProperties(test, known);
    375       return test & mask;
    376     } else {
    377       return impl_->Properties(mask);
    378     }
    379   }
    380 
    381   virtual const string& Type() const { return impl_->Type(); }
    382 
    383   virtual RelabelFst<A> *Copy() const {
    384     return new RelabelFst<A>(*this);
    385   }
    386 
    387   virtual const SymbolTable* InputSymbols() const {
    388     return impl_->InputSymbols();
    389   }
    390 
    391   virtual const SymbolTable* OutputSymbols() const {
    392     return impl_->OutputSymbols();
    393   }
    394 
    395   virtual void InitStateIterator(StateIteratorData<A> *data) const;
    396 
    397   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    398     return impl_->InitArcIterator(s, data);
    399   }
    400 
    401  private:
    402   RelabelFstImpl<A> *impl_;
    403 
    404   void operator=(const RelabelFst<A> &fst);  // disallow
    405 };
    406 
    407 // Specialization for RelabelFst.
    408 template<class A>
    409 class StateIterator< RelabelFst<A> > : public StateIteratorBase<A> {
    410  public:
    411   typedef typename A::StateId StateId;
    412 
    413   explicit StateIterator(const RelabelFst<A> &fst)
    414       : impl_(fst.impl_), siter_(*impl_->fst_), s_(0) {}
    415 
    416   bool Done() const { return siter_.Done(); }
    417 
    418   StateId Value() const { return s_; }
    419 
    420   void Next() {
    421     if (!siter_.Done()) {
    422       ++s_;
    423       siter_.Next();
    424     }
    425   }
    426 
    427   void Reset() {
    428     s_ = 0;
    429     siter_.Reset();
    430   }
    431 
    432  private:
    433   const RelabelFstImpl<A> *impl_;
    434   StateIterator< Fst<A> > siter_;
    435   StateId s_;
    436 
    437   DISALLOW_EVIL_CONSTRUCTORS(StateIterator);
    438 };
    439 
    440 
    441 // Specialization for RelabelFst.
    442 template <class A>
    443 class ArcIterator< RelabelFst<A> >
    444     : public CacheArcIterator< RelabelFst<A> > {
    445  public:
    446   typedef typename A::StateId StateId;
    447 
    448   ArcIterator(const RelabelFst<A> &fst, StateId s)
    449       : CacheArcIterator< RelabelFst<A> >(fst, s) {
    450     if (!fst.impl_->HasArcs(s))
    451       fst.impl_->Expand(s);
    452   }
    453 
    454  private:
    455   DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
    456 };
    457 
    458 template <class A> inline
    459 void RelabelFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
    460   data->base = new StateIterator< RelabelFst<A> >(*this);
    461 }
    462 
    463 // Useful alias when using StdArc.
    464 typedef RelabelFst<StdArc> StdRelabelFst;
    465 
    466 }  // namespace fst
    467 
    468 #endif  // FST_LIB_RELABEL_H__
    469