Home | History | Annotate | Download | only in lib
      1 // relabel.h
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 //
     16 // \file
     17 // Functions and classes to relabel an Fst (either on input or output)
     18 //
     19 #ifndef FST_LIB_RELABEL_H__
     20 #define FST_LIB_RELABEL_H__
     21 
     22 #include <ext/hash_map>
     23 using __gnu_cxx::hash_map;
     24 
     25 #include "fst/lib/cache.h"
     26 #include "fst/lib/test-properties.h"
     27 
     28 
     29 namespace fst {
     30 
     31 //
     32 // Relabels either the input labels or output labels. The old to
     33 // new labels are specified using a vector of pair<Label,Label>.
     34 // Any label associations not specified are assumed to be identity
     35 // mapping.
     36 //
     37 // \param fst input fst, must be mutable
     38 // \param relabel_pairs vector of pairs indicating old to new mapping
     39 // \param relabel_flags whether to relabel input or output
     40 //
     41 template <class A>
     42 void Relabel(
     43     MutableFst<A> *fst,
     44     const vector<pair<typename A::Label, typename A::Label> >& ipairs,
     45     const vector<pair<typename A::Label, typename A::Label> >& opairs) {
     46   typedef typename A::StateId StateId;
     47   typedef typename A::Label   Label;
     48 
     49   uint64 props = fst->Properties(kFstProperties, false);
     50 
     51   // construct label to label hash. Could
     52   hash_map<Label, Label> input_map;
     53   for (size_t i = 0; i < ipairs.size(); ++i) {
     54     input_map[ipairs[i].first] = ipairs[i].second;
     55   }
     56 
     57   hash_map<Label, Label> output_map;
     58   for (size_t i = 0; i < opairs.size(); ++i) {
     59     output_map[opairs[i].first] = opairs[i].second;
     60   }
     61 
     62   for (StateIterator<MutableFst<A> > siter(*fst);
     63        !siter.Done(); siter.Next()) {
     64     StateId s = siter.Value();
     65     for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
     66          !aiter.Done(); aiter.Next()) {
     67       A arc = aiter.Value();
     68 
     69       // relabel input
     70       // only relabel if relabel pair defined
     71       typename hash_map<Label, Label>::iterator it =
     72         input_map.find(arc.ilabel);
     73       if (it != input_map.end()) {arc.ilabel = it->second; }
     74 
     75       // relabel output
     76       it = output_map.find(arc.olabel);
     77       if (it != output_map.end()) { arc.olabel = it->second; }
     78 
     79       aiter.SetValue(arc);
     80     }
     81   }
     82 
     83   fst->SetProperties(RelabelProperties(props), kFstProperties);
     84 }
     85 
     86 
     87 
     88 //
     89 // Relabels either the input labels or output labels. The old to
     90 // new labels mappings are specified using an input Symbol set.
     91 // Any label associations not specified are assumed to be identity
     92 // mapping.
     93 //
     94 // \param fst input fst, must be mutable
     95 // \param new_symbols symbol set indicating new mapping
     96 // \param relabel_flags whether to relabel input or output
     97 //
     98 template<class A>
     99 void Relabel(MutableFst<A> *fst,
    100              const SymbolTable* new_isymbols,
    101              const SymbolTable* new_osymbols) {
    102   typedef typename A::StateId StateId;
    103   typedef typename A::Label   Label;
    104 
    105   const SymbolTable* old_isymbols = fst->InputSymbols();
    106   const SymbolTable* old_osymbols = fst->OutputSymbols();
    107 
    108   vector<pair<Label, Label> > ipairs;
    109   if (old_isymbols && new_isymbols) {
    110     for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
    111          syms_iter.Next()) {
    112       ipairs.push_back(make_pair(syms_iter.Value(),
    113                                  new_isymbols->Find(syms_iter.Symbol())));
    114     }
    115     fst->SetInputSymbols(new_isymbols);
    116   }
    117 
    118   vector<pair<Label, Label> > opairs;
    119   if (old_osymbols && new_osymbols) {
    120     for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
    121          syms_iter.Next()) {
    122       opairs.push_back(make_pair(syms_iter.Value(),
    123                                  new_osymbols->Find(syms_iter.Symbol())));
    124     }
    125     fst->SetOutputSymbols(new_osymbols);
    126   }
    127 
    128   // call relabel using vector of relabel pairs.
    129   Relabel(fst, ipairs, opairs);
    130 }
    131 
    132 
    133 typedef CacheOptions RelabelFstOptions;
    134 
    135 template <class A> class RelabelFst;
    136 
    137 //
    138 // \class RelabelFstImpl
    139 // \brief Implementation for delayed relabeling
    140 //
    141 // Relabels an FST from one symbol set to another. Relabeling
    142 // can either be on input or output space. RelabelFst implements
    143 // a delayed version of the relabel. Arcs are relabeled on the fly
    144 // and not cached. I.e each request is recomputed.
    145 //
    146 template<class A>
    147 class RelabelFstImpl : public CacheImpl<A> {
    148   friend class StateIterator< RelabelFst<A> >;
    149  public:
    150   using FstImpl<A>::SetType;
    151   using FstImpl<A>::SetProperties;
    152   using FstImpl<A>::Properties;
    153   using FstImpl<A>::SetInputSymbols;
    154   using FstImpl<A>::SetOutputSymbols;
    155 
    156   using CacheImpl<A>::HasStart;
    157   using CacheImpl<A>::HasArcs;
    158 
    159   typedef typename A::Label   Label;
    160   typedef typename A::Weight  Weight;
    161   typedef typename A::StateId StateId;
    162   typedef CacheState<A> State;
    163 
    164   RelabelFstImpl(const Fst<A>& fst,
    165                  const vector<pair<Label, Label> >& ipairs,
    166                  const vector<pair<Label, Label> >& opairs,
    167                  const RelabelFstOptions &opts)
    168       : CacheImpl<A>(opts), fst_(fst.Copy()),
    169         relabel_input_(false), relabel_output_(false) {
    170     uint64 props = fst.Properties(kCopyProperties, false);
    171     SetProperties(RelabelProperties(props));
    172     SetType("relabel");
    173 
    174     // create input label map
    175     if (ipairs.size() > 0) {
    176       for (size_t i = 0; i < ipairs.size(); ++i) {
    177         input_map_[ipairs[i].first] = ipairs[i].second;
    178       }
    179       relabel_input_ = true;
    180     }
    181 
    182     // create output label map
    183     if (opairs.size() > 0) {
    184       for (size_t i = 0; i < opairs.size(); ++i) {
    185         output_map_[opairs[i].first] = opairs[i].second;
    186       }
    187       relabel_output_ = true;
    188     }
    189   }
    190 
    191   RelabelFstImpl(const Fst<A>& fst,
    192                  const SymbolTable* new_isymbols,
    193                  const SymbolTable* new_osymbols,
    194                  const RelabelFstOptions &opts)
    195       : CacheImpl<A>(opts), fst_(fst.Copy()),
    196         relabel_input_(false), relabel_output_(false) {
    197     SetType("relabel");
    198 
    199     uint64 props = fst.Properties(kCopyProperties, false);
    200     SetProperties(RelabelProperties(props));
    201     SetInputSymbols(fst.InputSymbols());
    202     SetOutputSymbols(fst.OutputSymbols());
    203 
    204     const SymbolTable* old_isymbols = fst.InputSymbols();
    205     const SymbolTable* old_osymbols = fst.OutputSymbols();
    206 
    207     if (old_isymbols && new_isymbols &&
    208         old_isymbols->CheckSum() != new_isymbols->CheckSum()) {
    209       for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
    210            syms_iter.Next()) {
    211         input_map_[syms_iter.Value()] = new_isymbols->Find(syms_iter.Symbol());
    212       }
    213       SetInputSymbols(new_isymbols);
    214       relabel_input_ = true;
    215     }
    216 
    217     if (old_osymbols && new_osymbols &&
    218         old_osymbols->CheckSum() != new_osymbols->CheckSum()) {
    219       for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
    220            syms_iter.Next()) {
    221         output_map_[syms_iter.Value()] =
    222           new_osymbols->Find(syms_iter.Symbol());
    223       }
    224       SetOutputSymbols(new_osymbols);
    225       relabel_output_ = true;
    226     }
    227   }
    228 
    229   ~RelabelFstImpl() { delete fst_; }
    230 
    231   StateId Start() {
    232     if (!HasStart()) {
    233       StateId s = fst_->Start();
    234       SetStart(s);
    235     }
    236     return CacheImpl<A>::Start();
    237   }
    238 
    239   Weight Final(StateId s) {
    240     if (!HasFinal(s)) {
    241       SetFinal(s, fst_->Final(s));
    242     }
    243     return CacheImpl<A>::Final(s);
    244   }
    245 
    246   size_t NumArcs(StateId s) {
    247     if (!HasArcs(s)) {
    248       Expand(s);
    249     }
    250     return CacheImpl<A>::NumArcs(s);
    251   }
    252 
    253   size_t NumInputEpsilons(StateId s) {
    254     if (!HasArcs(s)) {
    255       Expand(s);
    256     }
    257     return CacheImpl<A>::NumInputEpsilons(s);
    258   }
    259 
    260   size_t NumOutputEpsilons(StateId s) {
    261     if (!HasArcs(s)) {
    262       Expand(s);
    263     }
    264     return CacheImpl<A>::NumOutputEpsilons(s);
    265   }
    266 
    267   void InitArcIterator(StateId s, ArcIteratorData<A>* data) {
    268     if (!HasArcs(s)) {
    269       Expand(s);
    270     }
    271     CacheImpl<A>::InitArcIterator(s, data);
    272   }
    273 
    274   void Expand(StateId s) {
    275     for (ArcIterator<Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
    276       A arc = aiter.Value();
    277 
    278       // relabel input
    279       if (relabel_input_) {
    280         typename hash_map<Label, Label>::iterator it =
    281           input_map_.find(arc.ilabel);
    282         if (it != input_map_.end()) { arc.ilabel = it->second; }
    283       }
    284 
    285       // relabel output
    286       if (relabel_output_) {
    287         typename hash_map<Label, Label>::iterator it =
    288           output_map_.find(arc.olabel);
    289         if (it != output_map_.end()) { arc.olabel = it->second; }
    290       }
    291 
    292       AddArc(s, arc);
    293     }
    294     SetArcs(s);
    295   }
    296 
    297 
    298  private:
    299   const Fst<A> *fst_;
    300 
    301   hash_map<Label, Label> input_map_;
    302   hash_map<Label, Label> output_map_;
    303   bool relabel_input_;
    304   bool relabel_output_;
    305 
    306   DISALLOW_EVIL_CONSTRUCTORS(RelabelFstImpl);
    307 };
    308 
    309 
    310 //
    311 // \class RelabelFst
    312 // \brief Delayed implementation of arc relabeling
    313 //
    314 // This class attaches interface to implementation and handles
    315 // reference counting.
    316 template <class A>
    317 class RelabelFst : public Fst<A> {
    318  public:
    319   friend class ArcIterator< RelabelFst<A> >;
    320   friend class StateIterator< RelabelFst<A> >;
    321   friend class CacheArcIterator< RelabelFst<A> >;
    322 
    323   typedef A Arc;
    324   typedef typename A::Label   Label;
    325   typedef typename A::Weight  Weight;
    326   typedef typename A::StateId StateId;
    327   typedef CacheState<A> State;
    328 
    329   RelabelFst(const Fst<A>& fst,
    330              const vector<pair<Label, Label> >& ipairs,
    331              const vector<pair<Label, Label> >& opairs) :
    332       impl_(new RelabelFstImpl<A>(fst, ipairs, opairs, RelabelFstOptions())) {}
    333 
    334   RelabelFst(const Fst<A>& fst,
    335              const vector<pair<Label, Label> >& ipairs,
    336              const vector<pair<Label, Label> >& opairs,
    337              const RelabelFstOptions &opts)
    338       : impl_(new RelabelFstImpl<A>(fst, ipairs, opairs, opts)) {}
    339 
    340   RelabelFst(const Fst<A>& fst,
    341              const SymbolTable* new_isymbols,
    342              const SymbolTable* new_osymbols) :
    343       impl_(new RelabelFstImpl<A>(fst, new_isymbols, new_osymbols,
    344                                   RelabelFstOptions())) {}
    345 
    346   RelabelFst(const Fst<A>& fst,
    347              const SymbolTable* new_isymbols,
    348              const SymbolTable* new_osymbols,
    349              const RelabelFstOptions &opts)
    350     : impl_(new RelabelFstImpl<A>(fst, new_isymbols, new_osymbols, opts)) {}
    351 
    352   RelabelFst(const RelabelFst<A> &fst) : impl_(fst.impl_) {
    353     impl_->IncrRefCount();
    354   }
    355 
    356   virtual ~RelabelFst() { if (!impl_->DecrRefCount()) delete impl_;  }
    357 
    358   virtual StateId Start() const { return impl_->Start(); }
    359 
    360   virtual Weight Final(StateId s) const { return impl_->Final(s); }
    361 
    362   virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
    363 
    364   virtual size_t NumInputEpsilons(StateId s) const {
    365     return impl_->NumInputEpsilons(s);
    366   }
    367 
    368   virtual size_t NumOutputEpsilons(StateId s) const {
    369     return impl_->NumOutputEpsilons(s);
    370   }
    371 
    372   virtual uint64 Properties(uint64 mask, bool test) const {
    373     if (test) {
    374       uint64 known, test = TestProperties(*this, mask, &known);
    375       impl_->SetProperties(test, known);
    376       return test & mask;
    377     } else {
    378       return impl_->Properties(mask);
    379     }
    380   }
    381 
    382   virtual const string& Type() const { return impl_->Type(); }
    383 
    384   virtual RelabelFst<A> *Copy() const {
    385     return new RelabelFst<A>(*this);
    386   }
    387 
    388   virtual const SymbolTable* InputSymbols() const {
    389     return impl_->InputSymbols();
    390   }
    391 
    392   virtual const SymbolTable* OutputSymbols() const {
    393     return impl_->OutputSymbols();
    394   }
    395 
    396   virtual void InitStateIterator(StateIteratorData<A> *data) const;
    397 
    398   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    399     return impl_->InitArcIterator(s, data);
    400   }
    401 
    402  private:
    403   RelabelFstImpl<A> *impl_;
    404 
    405   void operator=(const RelabelFst<A> &fst);  // disallow
    406 };
    407 
    408 // Specialization for RelabelFst.
    409 template<class A>
    410 class StateIterator< RelabelFst<A> > : public StateIteratorBase<A> {
    411  public:
    412   typedef typename A::StateId StateId;
    413 
    414   explicit StateIterator(const RelabelFst<A> &fst)
    415       : impl_(fst.impl_), siter_(*impl_->fst_), s_(0) {}
    416 
    417   bool Done() const { return siter_.Done(); }
    418 
    419   StateId Value() const { return s_; }
    420 
    421   void Next() {
    422     if (!siter_.Done()) {
    423       ++s_;
    424       siter_.Next();
    425     }
    426   }
    427 
    428   void Reset() {
    429     s_ = 0;
    430     siter_.Reset();
    431   }
    432 
    433  private:
    434   const RelabelFstImpl<A> *impl_;
    435   StateIterator< Fst<A> > siter_;
    436   StateId s_;
    437 
    438   DISALLOW_EVIL_CONSTRUCTORS(StateIterator);
    439 };
    440 
    441 
    442 // Specialization for RelabelFst.
    443 template <class A>
    444 class ArcIterator< RelabelFst<A> >
    445     : public CacheArcIterator< RelabelFst<A> > {
    446  public:
    447   typedef typename A::StateId StateId;
    448 
    449   ArcIterator(const RelabelFst<A> &fst, StateId s)
    450       : CacheArcIterator< RelabelFst<A> >(fst, s) {
    451     if (!fst.impl_->HasArcs(s))
    452       fst.impl_->Expand(s);
    453   }
    454 
    455  private:
    456   DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
    457 };
    458 
    459 template <class A> inline
    460 void RelabelFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
    461   data->base = new StateIterator< RelabelFst<A> >(*this);
    462 }
    463 
    464 // Useful alias when using StdArc.
    465 typedef RelabelFst<StdArc> StdRelabelFst;
    466 
    467 }  // namespace fst
    468 
    469 #endif  // FST_LIB_RELABEL_H__
    470