Home | History | Annotate | Download | only in fst
      1 // relabel.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: johans (at) google.com (Johan Schalkwyk)
     17 //
     18 // \file
     19 // Functions and classes to relabel an Fst (either on input or output)
     20 //
     21 #ifndef FST_LIB_RELABEL_H__
     22 #define FST_LIB_RELABEL_H__
     23 
     24 #include <tr1/unordered_map>
     25 using std::tr1::unordered_map;
     26 using std::tr1::unordered_multimap;
     27 #include <string>
     28 #include <utility>
     29 using std::pair; using std::make_pair;
     30 #include <vector>
     31 using std::vector;
     32 
     33 #include <fst/cache.h>
     34 #include <fst/test-properties.h>
     35 
     36 
     37 #include <tr1/unordered_map>
     38 using std::tr1::unordered_map;
     39 using std::tr1::unordered_multimap;
     40 
     41 namespace fst {
     42 
     43 //
     44 // Relabels either the input labels or output labels. The old to
     45 // new labels are specified using a vector of pair<Label,Label>.
     46 // Any label associations not specified are assumed to be identity
     47 // mapping.
     48 //
     49 // \param fst input fst, must be mutable
     50 // \param ipairs vector of input label pairs indicating old to new mapping
     51 // \param opairs vector of output label pairs indicating old to new mapping
     52 //
     53 template <class A>
     54 void Relabel(
     55     MutableFst<A> *fst,
     56     const vector<pair<typename A::Label, typename A::Label> >& ipairs,
     57     const vector<pair<typename A::Label, typename A::Label> >& opairs) {
     58   typedef typename A::StateId StateId;
     59   typedef typename A::Label   Label;
     60 
     61   uint64 props = fst->Properties(kFstProperties, false);
     62 
     63   // construct label to label hash.
     64   unordered_map<Label, Label> input_map;
     65   for (size_t i = 0; i < ipairs.size(); ++i) {
     66     input_map[ipairs[i].first] = ipairs[i].second;
     67   }
     68 
     69   unordered_map<Label, Label> output_map;
     70   for (size_t i = 0; i < opairs.size(); ++i) {
     71     output_map[opairs[i].first] = opairs[i].second;
     72   }
     73 
     74   for (StateIterator<MutableFst<A> > siter(*fst);
     75        !siter.Done(); siter.Next()) {
     76     StateId s = siter.Value();
     77     for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
     78          !aiter.Done(); aiter.Next()) {
     79       A arc = aiter.Value();
     80 
     81       // relabel input
     82       // only relabel if relabel pair defined
     83       typename unordered_map<Label, Label>::iterator it =
     84         input_map.find(arc.ilabel);
     85       if (it != input_map.end()) {
     86         if (it->second == kNoLabel) {
     87           FSTERROR() << "Input symbol id " << arc.ilabel
     88                      << " missing from target vocabulary";
     89           fst->SetProperties(kError, kError);
     90           return;
     91         }
     92         arc.ilabel = it->second;
     93       }
     94 
     95       // relabel output
     96       it = output_map.find(arc.olabel);
     97       if (it != output_map.end()) {
     98         if (it->second == kNoLabel) {
     99           FSTERROR() << "Output symbol id " << arc.olabel
    100                      << " missing from target vocabulary";
    101           fst->SetProperties(kError, kError);
    102           return;
    103         }
    104         arc.olabel = it->second;
    105       }
    106 
    107       aiter.SetValue(arc);
    108     }
    109   }
    110 
    111   fst->SetProperties(RelabelProperties(props), kFstProperties);
    112 }
    113 
    114 //
    115 // Relabels either the input labels or output labels. The old to
    116 // new labels mappings are specified using an input Symbol set.
    117 // Any label associations not specified are assumed to be identity
    118 // mapping.
    119 //
    120 // \param fst input fst, must be mutable
    121 // \param new_isymbols symbol set indicating new mapping of input symbols
    122 // \param new_osymbols symbol set indicating new mapping of output symbols
    123 //
    124 template<class A>
    125 void Relabel(MutableFst<A> *fst,
    126              const SymbolTable* new_isymbols,
    127              const SymbolTable* new_osymbols) {
    128   Relabel(fst,
    129           fst->InputSymbols(), new_isymbols, true,
    130           fst->OutputSymbols(), new_osymbols, true);
    131 }
    132 
    133 template<class A>
    134 void Relabel(MutableFst<A> *fst,
    135              const SymbolTable* old_isymbols,
    136              const SymbolTable* new_isymbols,
    137              bool attach_new_isymbols,
    138              const SymbolTable* old_osymbols,
    139              const SymbolTable* new_osymbols,
    140              bool attach_new_osymbols) {
    141   typedef typename A::StateId StateId;
    142   typedef typename A::Label   Label;
    143 
    144   vector<pair<Label, Label> > ipairs;
    145   if (old_isymbols && new_isymbols) {
    146     for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
    147          syms_iter.Next()) {
    148       string isymbol = syms_iter.Symbol();
    149       int isymbol_val = syms_iter.Value();
    150       int new_isymbol_val = new_isymbols->Find(isymbol);
    151       ipairs.push_back(make_pair(isymbol_val, new_isymbol_val));
    152     }
    153     if (attach_new_isymbols)
    154       fst->SetInputSymbols(new_isymbols);
    155   }
    156 
    157   vector<pair<Label, Label> > opairs;
    158   if (old_osymbols && new_osymbols) {
    159     for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
    160          syms_iter.Next()) {
    161       string osymbol = syms_iter.Symbol();
    162       int osymbol_val = syms_iter.Value();
    163       int new_osymbol_val = new_osymbols->Find(osymbol);
    164       opairs.push_back(make_pair(osymbol_val, new_osymbol_val));
    165     }
    166     if (attach_new_osymbols)
    167       fst->SetOutputSymbols(new_osymbols);
    168   }
    169 
    170   // call relabel using vector of relabel pairs.
    171   Relabel(fst, ipairs, opairs);
    172 }
    173 
    174 
    175 typedef CacheOptions RelabelFstOptions;
    176 
    177 template <class A> class RelabelFst;
    178 
    179 //
    180 // \class RelabelFstImpl
    181 // \brief Implementation for delayed relabeling
    182 //
    183 // Relabels an FST from one symbol set to another. Relabeling
    184 // can either be on input or output space. RelabelFst implements
    185 // a delayed version of the relabel. Arcs are relabeled on the fly
    186 // and not cached. I.e each request is recomputed.
    187 //
    188 template<class A>
    189 class RelabelFstImpl : public CacheImpl<A> {
    190   friend class StateIterator< RelabelFst<A> >;
    191  public:
    192   using FstImpl<A>::SetType;
    193   using FstImpl<A>::SetProperties;
    194   using FstImpl<A>::WriteHeader;
    195   using FstImpl<A>::SetInputSymbols;
    196   using FstImpl<A>::SetOutputSymbols;
    197 
    198   using CacheImpl<A>::PushArc;
    199   using CacheImpl<A>::HasArcs;
    200   using CacheImpl<A>::HasFinal;
    201   using CacheImpl<A>::HasStart;
    202   using CacheImpl<A>::SetArcs;
    203   using CacheImpl<A>::SetFinal;
    204   using CacheImpl<A>::SetStart;
    205 
    206   typedef A Arc;
    207   typedef typename A::Label   Label;
    208   typedef typename A::Weight  Weight;
    209   typedef typename A::StateId StateId;
    210   typedef CacheState<A> State;
    211 
    212   RelabelFstImpl(const Fst<A>& fst,
    213                  const vector<pair<Label, Label> >& ipairs,
    214                  const vector<pair<Label, Label> >& opairs,
    215                  const RelabelFstOptions &opts)
    216       : CacheImpl<A>(opts), fst_(fst.Copy()),
    217         relabel_input_(false), relabel_output_(false) {
    218     uint64 props = fst.Properties(kCopyProperties, false);
    219     SetProperties(RelabelProperties(props));
    220     SetType("relabel");
    221 
    222     // create input label map
    223     if (ipairs.size() > 0) {
    224       for (size_t i = 0; i < ipairs.size(); ++i) {
    225         input_map_[ipairs[i].first] = ipairs[i].second;
    226       }
    227       relabel_input_ = true;
    228     }
    229 
    230     // create output label map
    231     if (opairs.size() > 0) {
    232       for (size_t i = 0; i < opairs.size(); ++i) {
    233         output_map_[opairs[i].first] = opairs[i].second;
    234       }
    235       relabel_output_ = true;
    236     }
    237   }
    238 
    239   RelabelFstImpl(const Fst<A>& fst,
    240                  const SymbolTable* old_isymbols,
    241                  const SymbolTable* new_isymbols,
    242                  const SymbolTable* old_osymbols,
    243                  const SymbolTable* new_osymbols,
    244                  const RelabelFstOptions &opts)
    245       : CacheImpl<A>(opts), fst_(fst.Copy()),
    246         relabel_input_(false), relabel_output_(false) {
    247     SetType("relabel");
    248 
    249     uint64 props = fst.Properties(kCopyProperties, false);
    250     SetProperties(RelabelProperties(props));
    251     SetInputSymbols(old_isymbols);
    252     SetOutputSymbols(old_osymbols);
    253 
    254     if (old_isymbols && new_isymbols &&
    255         old_isymbols->LabeledCheckSum() != new_isymbols->LabeledCheckSum()) {
    256       for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
    257            syms_iter.Next()) {
    258         input_map_[syms_iter.Value()] = new_isymbols->Find(syms_iter.Symbol());
    259       }
    260       SetInputSymbols(new_isymbols);
    261       relabel_input_ = true;
    262     }
    263 
    264     if (old_osymbols && new_osymbols &&
    265         old_osymbols->LabeledCheckSum() != new_osymbols->LabeledCheckSum()) {
    266       for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
    267            syms_iter.Next()) {
    268         output_map_[syms_iter.Value()] =
    269           new_osymbols->Find(syms_iter.Symbol());
    270       }
    271       SetOutputSymbols(new_osymbols);
    272       relabel_output_ = true;
    273     }
    274   }
    275 
    276   RelabelFstImpl(const RelabelFstImpl<A>& impl)
    277       : CacheImpl<A>(impl),
    278         fst_(impl.fst_->Copy(true)),
    279         input_map_(impl.input_map_),
    280         output_map_(impl.output_map_),
    281         relabel_input_(impl.relabel_input_),
    282         relabel_output_(impl.relabel_output_) {
    283     SetType("relabel");
    284     SetProperties(impl.Properties(), kCopyProperties);
    285     SetInputSymbols(impl.InputSymbols());
    286     SetOutputSymbols(impl.OutputSymbols());
    287   }
    288 
    289   ~RelabelFstImpl() { delete fst_; }
    290 
    291   StateId Start() {
    292     if (!HasStart()) {
    293       StateId s = fst_->Start();
    294       SetStart(s);
    295     }
    296     return CacheImpl<A>::Start();
    297   }
    298 
    299   Weight Final(StateId s) {
    300     if (!HasFinal(s)) {
    301       SetFinal(s, fst_->Final(s));
    302     }
    303     return CacheImpl<A>::Final(s);
    304   }
    305 
    306   size_t NumArcs(StateId s) {
    307     if (!HasArcs(s)) {
    308       Expand(s);
    309     }
    310     return CacheImpl<A>::NumArcs(s);
    311   }
    312 
    313   size_t NumInputEpsilons(StateId s) {
    314     if (!HasArcs(s)) {
    315       Expand(s);
    316     }
    317     return CacheImpl<A>::NumInputEpsilons(s);
    318   }
    319 
    320   size_t NumOutputEpsilons(StateId s) {
    321     if (!HasArcs(s)) {
    322       Expand(s);
    323     }
    324     return CacheImpl<A>::NumOutputEpsilons(s);
    325   }
    326 
    327   uint64 Properties() const { return Properties(kFstProperties); }
    328 
    329   // Set error if found; return FST impl properties.
    330   uint64 Properties(uint64 mask) const {
    331     if ((mask & kError) && fst_->Properties(kError, false))
    332       SetProperties(kError, kError);
    333     return FstImpl<Arc>::Properties(mask);
    334   }
    335 
    336   void InitArcIterator(StateId s, ArcIteratorData<A>* data) {
    337     if (!HasArcs(s)) {
    338       Expand(s);
    339     }
    340     CacheImpl<A>::InitArcIterator(s, data);
    341   }
    342 
    343   void Expand(StateId s) {
    344     for (ArcIterator<Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
    345       A arc = aiter.Value();
    346 
    347       // relabel input
    348       if (relabel_input_) {
    349         typename unordered_map<Label, Label>::iterator it =
    350           input_map_.find(arc.ilabel);
    351         if (it != input_map_.end()) { arc.ilabel = it->second; }
    352       }
    353 
    354       // relabel output
    355       if (relabel_output_) {
    356         typename unordered_map<Label, Label>::iterator it =
    357           output_map_.find(arc.olabel);
    358         if (it != output_map_.end()) { arc.olabel = it->second; }
    359       }
    360 
    361       PushArc(s, arc);
    362     }
    363     SetArcs(s);
    364   }
    365 
    366 
    367  private:
    368   const Fst<A> *fst_;
    369 
    370   unordered_map<Label, Label> input_map_;
    371   unordered_map<Label, Label> output_map_;
    372   bool relabel_input_;
    373   bool relabel_output_;
    374 
    375   void operator=(const RelabelFstImpl<A> &);  // disallow
    376 };
    377 
    378 
    379 //
    380 // \class RelabelFst
    381 // \brief Delayed implementation of arc relabeling
    382 //
    383 // This class attaches interface to implementation and handles
    384 // reference counting, delegating most methods to ImplToFst.
    385 template <class A>
    386 class RelabelFst : public ImplToFst< RelabelFstImpl<A> > {
    387  public:
    388   friend class ArcIterator< RelabelFst<A> >;
    389   friend class StateIterator< RelabelFst<A> >;
    390 
    391   typedef A Arc;
    392   typedef typename A::Label   Label;
    393   typedef typename A::Weight  Weight;
    394   typedef typename A::StateId StateId;
    395   typedef CacheState<A> State;
    396   typedef RelabelFstImpl<A> Impl;
    397 
    398   RelabelFst(const Fst<A>& fst,
    399              const vector<pair<Label, Label> >& ipairs,
    400              const vector<pair<Label, Label> >& opairs)
    401       : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, RelabelFstOptions())) {}
    402 
    403   RelabelFst(const Fst<A>& fst,
    404              const vector<pair<Label, Label> >& ipairs,
    405              const vector<pair<Label, Label> >& opairs,
    406              const RelabelFstOptions &opts)
    407       : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, opts)) {}
    408 
    409   RelabelFst(const Fst<A>& fst,
    410              const SymbolTable* new_isymbols,
    411              const SymbolTable* new_osymbols)
    412       : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols,
    413                                  fst.OutputSymbols(), new_osymbols,
    414                                  RelabelFstOptions())) {}
    415 
    416   RelabelFst(const Fst<A>& fst,
    417              const SymbolTable* new_isymbols,
    418              const SymbolTable* new_osymbols,
    419              const RelabelFstOptions &opts)
    420       : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols,
    421                                  fst.OutputSymbols(), new_osymbols, opts)) {}
    422 
    423   RelabelFst(const Fst<A>& fst,
    424              const SymbolTable* old_isymbols,
    425              const SymbolTable* new_isymbols,
    426              const SymbolTable* old_osymbols,
    427              const SymbolTable* new_osymbols)
    428     : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols,
    429                                new_osymbols, RelabelFstOptions())) {}
    430 
    431   RelabelFst(const Fst<A>& fst,
    432              const SymbolTable* old_isymbols,
    433              const SymbolTable* new_isymbols,
    434              const SymbolTable* old_osymbols,
    435              const SymbolTable* new_osymbols,
    436              const RelabelFstOptions &opts)
    437     : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols,
    438                                new_osymbols, opts)) {}
    439 
    440   // See Fst<>::Copy() for doc.
    441   RelabelFst(const RelabelFst<A> &fst, bool safe = false)
    442     : ImplToFst<Impl>(fst, safe) {}
    443 
    444   // Get a copy of this RelabelFst. See Fst<>::Copy() for further doc.
    445   virtual RelabelFst<A> *Copy(bool safe = false) const {
    446     return new RelabelFst<A>(*this, safe);
    447   }
    448 
    449   virtual void InitStateIterator(StateIteratorData<A> *data) const;
    450 
    451   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    452     return GetImpl()->InitArcIterator(s, data);
    453   }
    454 
    455  private:
    456   // Makes visible to friends.
    457   Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
    458 
    459   void operator=(const RelabelFst<A> &fst);  // disallow
    460 };
    461 
    462 // Specialization for RelabelFst.
    463 template<class A>
    464 class StateIterator< RelabelFst<A> > : public StateIteratorBase<A> {
    465  public:
    466   typedef typename A::StateId StateId;
    467 
    468   explicit StateIterator(const RelabelFst<A> &fst)
    469       : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {}
    470 
    471   bool Done() const { return siter_.Done(); }
    472 
    473   StateId Value() const { return s_; }
    474 
    475   void Next() {
    476     if (!siter_.Done()) {
    477       ++s_;
    478       siter_.Next();
    479     }
    480   }
    481 
    482   void Reset() {
    483     s_ = 0;
    484     siter_.Reset();
    485   }
    486 
    487  private:
    488   bool Done_() const { return Done(); }
    489   StateId Value_() const { return Value(); }
    490   void Next_() { Next(); }
    491   void Reset_() { Reset(); }
    492 
    493   const RelabelFstImpl<A> *impl_;
    494   StateIterator< Fst<A> > siter_;
    495   StateId s_;
    496 
    497   DISALLOW_COPY_AND_ASSIGN(StateIterator);
    498 };
    499 
    500 
    501 // Specialization for RelabelFst.
    502 template <class A>
    503 class ArcIterator< RelabelFst<A> >
    504     : public CacheArcIterator< RelabelFst<A> > {
    505  public:
    506   typedef typename A::StateId StateId;
    507 
    508   ArcIterator(const RelabelFst<A> &fst, StateId s)
    509       : CacheArcIterator< RelabelFst<A> >(fst.GetImpl(), s) {
    510     if (!fst.GetImpl()->HasArcs(s))
    511       fst.GetImpl()->Expand(s);
    512   }
    513 
    514  private:
    515   DISALLOW_COPY_AND_ASSIGN(ArcIterator);
    516 };
    517 
    518 template <class A> inline
    519 void RelabelFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
    520   data->base = new StateIterator< RelabelFst<A> >(*this);
    521 }
    522 
    523 // Useful alias when using StdArc.
    524 typedef RelabelFst<StdArc> StdRelabelFst;
    525 
    526 }  // namespace fst
    527 
    528 #endif  // FST_LIB_RELABEL_H__
    529