Home | History | Annotate | Download | only in fst
      1 // relabel.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: johans (at) google.com (Johan Schalkwyk)
     17 //
     18 // \file
     19 // Functions and classes to relabel an Fst (either on input or output)
     20 //
     21 #ifndef FST_LIB_RELABEL_H__
     22 #define FST_LIB_RELABEL_H__
     23 
     24 #include <unordered_map>
     25 using std::tr1::unordered_map;
     26 using std::tr1::unordered_multimap;
     27 #include <string>
     28 #include <utility>
     29 using std::pair; using std::make_pair;
     30 #include <vector>
     31 using std::vector;
     32 
     33 #include <fst/cache.h>
     34 #include <fst/test-properties.h>
     35 
     36 
     37 namespace fst {
     38 
     39 //
     40 // Relabels either the input labels or output labels. The old to
     41 // new labels are specified using a vector of pair<Label,Label>.
     42 // Any label associations not specified are assumed to be identity
     43 // mapping.
     44 //
     45 // \param fst input fst, must be mutable
     46 // \param ipairs vector of input label pairs indicating old to new mapping
     47 // \param opairs vector of output label pairs indicating old to new mapping
     48 //
     49 template <class A>
     50 void Relabel(
     51     MutableFst<A> *fst,
     52     const vector<pair<typename A::Label, typename A::Label> >& ipairs,
     53     const vector<pair<typename A::Label, typename A::Label> >& opairs) {
     54   typedef typename A::StateId StateId;
     55   typedef typename A::Label   Label;
     56 
     57   uint64 props = fst->Properties(kFstProperties, false);
     58 
     59   // construct label to label hash.
     60   unordered_map<Label, Label> input_map;
     61   for (size_t i = 0; i < ipairs.size(); ++i) {
     62     input_map[ipairs[i].first] = ipairs[i].second;
     63   }
     64 
     65   unordered_map<Label, Label> output_map;
     66   for (size_t i = 0; i < opairs.size(); ++i) {
     67     output_map[opairs[i].first] = opairs[i].second;
     68   }
     69 
     70   for (StateIterator<MutableFst<A> > siter(*fst);
     71        !siter.Done(); siter.Next()) {
     72     StateId s = siter.Value();
     73     for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
     74          !aiter.Done(); aiter.Next()) {
     75       A arc = aiter.Value();
     76 
     77       // relabel input
     78       // only relabel if relabel pair defined
     79       typename unordered_map<Label, Label>::iterator it =
     80         input_map.find(arc.ilabel);
     81       if (it != input_map.end()) {
     82         if (it->second == kNoLabel) {
     83           FSTERROR() << "Input symbol id " << arc.ilabel
     84                      << " missing from target vocabulary";
     85           fst->SetProperties(kError, kError);
     86           return;
     87         }
     88         arc.ilabel = it->second;
     89       }
     90 
     91       // relabel output
     92       it = output_map.find(arc.olabel);
     93       if (it != output_map.end()) {
     94         if (it->second == kNoLabel) {
     95           FSTERROR() << "Output symbol id " << arc.olabel
     96                      << " missing from target vocabulary";
     97           fst->SetProperties(kError, kError);
     98           return;
     99         }
    100         arc.olabel = it->second;
    101       }
    102 
    103       aiter.SetValue(arc);
    104     }
    105   }
    106 
    107   fst->SetProperties(RelabelProperties(props), kFstProperties);
    108 }
    109 
    110 //
    111 // Relabels either the input labels or output labels. The old to
    112 // new labels mappings are specified using an input Symbol set.
    113 // Any label associations not specified are assumed to be identity
    114 // mapping.
    115 //
    116 // \param fst input fst, must be mutable
    117 // \param new_isymbols symbol set indicating new mapping of input symbols
    118 // \param new_osymbols symbol set indicating new mapping of output symbols
    119 //
    120 template<class A>
    121 void Relabel(MutableFst<A> *fst,
    122              const SymbolTable* new_isymbols,
    123              const SymbolTable* new_osymbols) {
    124   Relabel(fst,
    125           fst->InputSymbols(), new_isymbols, true,
    126           fst->OutputSymbols(), new_osymbols, true);
    127 }
    128 
    129 template<class A>
    130 void Relabel(MutableFst<A> *fst,
    131              const SymbolTable* old_isymbols,
    132              const SymbolTable* new_isymbols,
    133              bool attach_new_isymbols,
    134              const SymbolTable* old_osymbols,
    135              const SymbolTable* new_osymbols,
    136              bool attach_new_osymbols) {
    137   typedef typename A::StateId StateId;
    138   typedef typename A::Label   Label;
    139 
    140   vector<pair<Label, Label> > ipairs;
    141   if (old_isymbols && new_isymbols) {
    142     for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
    143          syms_iter.Next()) {
    144       string isymbol = syms_iter.Symbol();
    145       int isymbol_val = syms_iter.Value();
    146       int new_isymbol_val = new_isymbols->Find(isymbol);
    147       ipairs.push_back(make_pair(isymbol_val, new_isymbol_val));
    148     }
    149     if (attach_new_isymbols)
    150       fst->SetInputSymbols(new_isymbols);
    151   }
    152 
    153   vector<pair<Label, Label> > opairs;
    154   if (old_osymbols && new_osymbols) {
    155     for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
    156          syms_iter.Next()) {
    157       string osymbol = syms_iter.Symbol();
    158       int osymbol_val = syms_iter.Value();
    159       int new_osymbol_val = new_osymbols->Find(osymbol);
    160       opairs.push_back(make_pair(osymbol_val, new_osymbol_val));
    161     }
    162     if (attach_new_osymbols)
    163       fst->SetOutputSymbols(new_osymbols);
    164   }
    165 
    166   // call relabel using vector of relabel pairs.
    167   Relabel(fst, ipairs, opairs);
    168 }
    169 
    170 
    171 typedef CacheOptions RelabelFstOptions;
    172 
    173 template <class A> class RelabelFst;
    174 
    175 //
    176 // \class RelabelFstImpl
    177 // \brief Implementation for delayed relabeling
    178 //
    179 // Relabels an FST from one symbol set to another. Relabeling
    180 // can either be on input or output space. RelabelFst implements
    181 // a delayed version of the relabel. Arcs are relabeled on the fly
    182 // and not cached. I.e each request is recomputed.
    183 //
    184 template<class A>
    185 class RelabelFstImpl : public CacheImpl<A> {
    186   friend class StateIterator< RelabelFst<A> >;
    187  public:
    188   using FstImpl<A>::SetType;
    189   using FstImpl<A>::SetProperties;
    190   using FstImpl<A>::WriteHeader;
    191   using FstImpl<A>::SetInputSymbols;
    192   using FstImpl<A>::SetOutputSymbols;
    193 
    194   using CacheImpl<A>::PushArc;
    195   using CacheImpl<A>::HasArcs;
    196   using CacheImpl<A>::HasFinal;
    197   using CacheImpl<A>::HasStart;
    198   using CacheImpl<A>::SetArcs;
    199   using CacheImpl<A>::SetFinal;
    200   using CacheImpl<A>::SetStart;
    201 
    202   typedef A Arc;
    203   typedef typename A::Label   Label;
    204   typedef typename A::Weight  Weight;
    205   typedef typename A::StateId StateId;
    206   typedef CacheState<A> State;
    207 
    208   RelabelFstImpl(const Fst<A>& fst,
    209                  const vector<pair<Label, Label> >& ipairs,
    210                  const vector<pair<Label, Label> >& opairs,
    211                  const RelabelFstOptions &opts)
    212       : CacheImpl<A>(opts), fst_(fst.Copy()),
    213         relabel_input_(false), relabel_output_(false) {
    214     uint64 props = fst.Properties(kCopyProperties, false);
    215     SetProperties(RelabelProperties(props));
    216     SetType("relabel");
    217 
    218     // create input label map
    219     if (ipairs.size() > 0) {
    220       for (size_t i = 0; i < ipairs.size(); ++i) {
    221         input_map_[ipairs[i].first] = ipairs[i].second;
    222       }
    223       relabel_input_ = true;
    224     }
    225 
    226     // create output label map
    227     if (opairs.size() > 0) {
    228       for (size_t i = 0; i < opairs.size(); ++i) {
    229         output_map_[opairs[i].first] = opairs[i].second;
    230       }
    231       relabel_output_ = true;
    232     }
    233   }
    234 
    235   RelabelFstImpl(const Fst<A>& fst,
    236                  const SymbolTable* old_isymbols,
    237                  const SymbolTable* new_isymbols,
    238                  const SymbolTable* old_osymbols,
    239                  const SymbolTable* new_osymbols,
    240                  const RelabelFstOptions &opts)
    241       : CacheImpl<A>(opts), fst_(fst.Copy()),
    242         relabel_input_(false), relabel_output_(false) {
    243     SetType("relabel");
    244 
    245     uint64 props = fst.Properties(kCopyProperties, false);
    246     SetProperties(RelabelProperties(props));
    247     SetInputSymbols(old_isymbols);
    248     SetOutputSymbols(old_osymbols);
    249 
    250     if (old_isymbols && new_isymbols &&
    251         old_isymbols->LabeledCheckSum() != new_isymbols->LabeledCheckSum()) {
    252       for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
    253            syms_iter.Next()) {
    254         input_map_[syms_iter.Value()] = new_isymbols->Find(syms_iter.Symbol());
    255       }
    256       SetInputSymbols(new_isymbols);
    257       relabel_input_ = true;
    258     }
    259 
    260     if (old_osymbols && new_osymbols &&
    261         old_osymbols->LabeledCheckSum() != new_osymbols->LabeledCheckSum()) {
    262       for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
    263            syms_iter.Next()) {
    264         output_map_[syms_iter.Value()] =
    265           new_osymbols->Find(syms_iter.Symbol());
    266       }
    267       SetOutputSymbols(new_osymbols);
    268       relabel_output_ = true;
    269     }
    270   }
    271 
    272   RelabelFstImpl(const RelabelFstImpl<A>& impl)
    273       : CacheImpl<A>(impl),
    274         fst_(impl.fst_->Copy(true)),
    275         input_map_(impl.input_map_),
    276         output_map_(impl.output_map_),
    277         relabel_input_(impl.relabel_input_),
    278         relabel_output_(impl.relabel_output_) {
    279     SetType("relabel");
    280     SetProperties(impl.Properties(), kCopyProperties);
    281     SetInputSymbols(impl.InputSymbols());
    282     SetOutputSymbols(impl.OutputSymbols());
    283   }
    284 
    285   ~RelabelFstImpl() { delete fst_; }
    286 
    287   StateId Start() {
    288     if (!HasStart()) {
    289       StateId s = fst_->Start();
    290       SetStart(s);
    291     }
    292     return CacheImpl<A>::Start();
    293   }
    294 
    295   Weight Final(StateId s) {
    296     if (!HasFinal(s)) {
    297       SetFinal(s, fst_->Final(s));
    298     }
    299     return CacheImpl<A>::Final(s);
    300   }
    301 
    302   size_t NumArcs(StateId s) {
    303     if (!HasArcs(s)) {
    304       Expand(s);
    305     }
    306     return CacheImpl<A>::NumArcs(s);
    307   }
    308 
    309   size_t NumInputEpsilons(StateId s) {
    310     if (!HasArcs(s)) {
    311       Expand(s);
    312     }
    313     return CacheImpl<A>::NumInputEpsilons(s);
    314   }
    315 
    316   size_t NumOutputEpsilons(StateId s) {
    317     if (!HasArcs(s)) {
    318       Expand(s);
    319     }
    320     return CacheImpl<A>::NumOutputEpsilons(s);
    321   }
    322 
    323   uint64 Properties() const { return Properties(kFstProperties); }
    324 
    325   // Set error if found; return FST impl properties.
    326   uint64 Properties(uint64 mask) const {
    327     if ((mask & kError) && fst_->Properties(kError, false))
    328       SetProperties(kError, kError);
    329     return FstImpl<Arc>::Properties(mask);
    330   }
    331 
    332   void InitArcIterator(StateId s, ArcIteratorData<A>* data) {
    333     if (!HasArcs(s)) {
    334       Expand(s);
    335     }
    336     CacheImpl<A>::InitArcIterator(s, data);
    337   }
    338 
    339   void Expand(StateId s) {
    340     for (ArcIterator<Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
    341       A arc = aiter.Value();
    342 
    343       // relabel input
    344       if (relabel_input_) {
    345         typename unordered_map<Label, Label>::iterator it =
    346           input_map_.find(arc.ilabel);
    347         if (it != input_map_.end()) { arc.ilabel = it->second; }
    348       }
    349 
    350       // relabel output
    351       if (relabel_output_) {
    352         typename unordered_map<Label, Label>::iterator it =
    353           output_map_.find(arc.olabel);
    354         if (it != output_map_.end()) { arc.olabel = it->second; }
    355       }
    356 
    357       PushArc(s, arc);
    358     }
    359     SetArcs(s);
    360   }
    361 
    362 
    363  private:
    364   const Fst<A> *fst_;
    365 
    366   unordered_map<Label, Label> input_map_;
    367   unordered_map<Label, Label> output_map_;
    368   bool relabel_input_;
    369   bool relabel_output_;
    370 
    371   void operator=(const RelabelFstImpl<A> &);  // disallow
    372 };
    373 
    374 
    375 //
    376 // \class RelabelFst
    377 // \brief Delayed implementation of arc relabeling
    378 //
    379 // This class attaches interface to implementation and handles
    380 // reference counting, delegating most methods to ImplToFst.
    381 template <class A>
    382 class RelabelFst : public ImplToFst< RelabelFstImpl<A> > {
    383  public:
    384   friend class ArcIterator< RelabelFst<A> >;
    385   friend class StateIterator< RelabelFst<A> >;
    386 
    387   typedef A Arc;
    388   typedef typename A::Label   Label;
    389   typedef typename A::Weight  Weight;
    390   typedef typename A::StateId StateId;
    391   typedef CacheState<A> State;
    392   typedef RelabelFstImpl<A> Impl;
    393 
    394   RelabelFst(const Fst<A>& fst,
    395              const vector<pair<Label, Label> >& ipairs,
    396              const vector<pair<Label, Label> >& opairs)
    397       : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, RelabelFstOptions())) {}
    398 
    399   RelabelFst(const Fst<A>& fst,
    400              const vector<pair<Label, Label> >& ipairs,
    401              const vector<pair<Label, Label> >& opairs,
    402              const RelabelFstOptions &opts)
    403       : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, opts)) {}
    404 
    405   RelabelFst(const Fst<A>& fst,
    406              const SymbolTable* new_isymbols,
    407              const SymbolTable* new_osymbols)
    408       : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols,
    409                                  fst.OutputSymbols(), new_osymbols,
    410                                  RelabelFstOptions())) {}
    411 
    412   RelabelFst(const Fst<A>& fst,
    413              const SymbolTable* new_isymbols,
    414              const SymbolTable* new_osymbols,
    415              const RelabelFstOptions &opts)
    416       : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols,
    417                                  fst.OutputSymbols(), new_osymbols, opts)) {}
    418 
    419   RelabelFst(const Fst<A>& fst,
    420              const SymbolTable* old_isymbols,
    421              const SymbolTable* new_isymbols,
    422              const SymbolTable* old_osymbols,
    423              const SymbolTable* new_osymbols)
    424     : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols,
    425                                new_osymbols, RelabelFstOptions())) {}
    426 
    427   RelabelFst(const Fst<A>& fst,
    428              const SymbolTable* old_isymbols,
    429              const SymbolTable* new_isymbols,
    430              const SymbolTable* old_osymbols,
    431              const SymbolTable* new_osymbols,
    432              const RelabelFstOptions &opts)
    433     : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols,
    434                                new_osymbols, opts)) {}
    435 
    436   // See Fst<>::Copy() for doc.
    437   RelabelFst(const RelabelFst<A> &fst, bool safe = false)
    438     : ImplToFst<Impl>(fst, safe) {}
    439 
    440   // Get a copy of this RelabelFst. See Fst<>::Copy() for further doc.
    441   virtual RelabelFst<A> *Copy(bool safe = false) const {
    442     return new RelabelFst<A>(*this, safe);
    443   }
    444 
    445   virtual void InitStateIterator(StateIteratorData<A> *data) const;
    446 
    447   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    448     return GetImpl()->InitArcIterator(s, data);
    449   }
    450 
    451  private:
    452   // Makes visible to friends.
    453   Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
    454 
    455   void operator=(const RelabelFst<A> &fst);  // disallow
    456 };
    457 
    458 // Specialization for RelabelFst.
    459 template<class A>
    460 class StateIterator< RelabelFst<A> > : public StateIteratorBase<A> {
    461  public:
    462   typedef typename A::StateId StateId;
    463 
    464   explicit StateIterator(const RelabelFst<A> &fst)
    465       : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {}
    466 
    467   bool Done() const { return siter_.Done(); }
    468 
    469   StateId Value() const { return s_; }
    470 
    471   void Next() {
    472     if (!siter_.Done()) {
    473       ++s_;
    474       siter_.Next();
    475     }
    476   }
    477 
    478   void Reset() {
    479     s_ = 0;
    480     siter_.Reset();
    481   }
    482 
    483  private:
    484   bool Done_() const { return Done(); }
    485   StateId Value_() const { return Value(); }
    486   void Next_() { Next(); }
    487   void Reset_() { Reset(); }
    488 
    489   const RelabelFstImpl<A> *impl_;
    490   StateIterator< Fst<A> > siter_;
    491   StateId s_;
    492 
    493   DISALLOW_COPY_AND_ASSIGN(StateIterator);
    494 };
    495 
    496 
    497 // Specialization for RelabelFst.
    498 template <class A>
    499 class ArcIterator< RelabelFst<A> >
    500     : public CacheArcIterator< RelabelFst<A> > {
    501  public:
    502   typedef typename A::StateId StateId;
    503 
    504   ArcIterator(const RelabelFst<A> &fst, StateId s)
    505       : CacheArcIterator< RelabelFst<A> >(fst.GetImpl(), s) {
    506     if (!fst.GetImpl()->HasArcs(s))
    507       fst.GetImpl()->Expand(s);
    508   }
    509 
    510  private:
    511   DISALLOW_COPY_AND_ASSIGN(ArcIterator);
    512 };
    513 
    514 template <class A> inline
    515 void RelabelFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
    516   data->base = new StateIterator< RelabelFst<A> >(*this);
    517 }
    518 
    519 // Useful alias when using StdArc.
    520 typedef RelabelFst<StdArc> StdRelabelFst;
    521 
    522 }  // namespace fst
    523 
    524 #endif  // FST_LIB_RELABEL_H__
    525