Home | History | Annotate | Download | only in fst
      1 // replace-util.h
      2 
      3 
      4 // Licensed under the Apache License, Version 2.0 (the "License");
      5 // you may not use this file except in compliance with the License.
      6 // You may obtain a copy of the License at
      7 //
      8 //     http://www.apache.org/licenses/LICENSE-2.0
      9 //
     10 // Unless required by applicable law or agreed to in writing, software
     11 // distributed under the License is distributed on an "AS IS" BASIS,
     12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 // See the License for the specific language governing permissions and
     14 // limitations under the License.
     15 //
     16 // Copyright 2005-2010 Google, Inc.
     17 // Author: riley (at) google.com (Michael Riley)
     18 //
     19 
     20 // \file
     21 // Utility classes for the recursive replacement of Fsts (RTNs).
     22 
     23 #ifndef FST_LIB_REPLACE_UTIL_H__
     24 #define FST_LIB_REPLACE_UTIL_H__
     25 
     26 #include <vector>
     27 using std::vector;
     28 #include <tr1/unordered_map>
     29 using std::tr1::unordered_map;
     30 using std::tr1::unordered_multimap;
     31 #include <tr1/unordered_set>
     32 using std::tr1::unordered_set;
     33 using std::tr1::unordered_multiset;
     34 #include <map>
     35 
     36 #include <fst/connect.h>
     37 #include <fst/mutable-fst.h>
     38 #include <fst/topsort.h>
     39 
     40 
     41 namespace fst {
     42 
     43 template <class Arc>
     44 void Replace(const vector<pair<typename Arc::Label, const Fst<Arc>* > >&,
     45              MutableFst<Arc> *, typename Arc::Label, bool);
     46 
     47 
     48 // Utility class for the recursive replacement of Fsts (RTNs). The
     49 // user provides a set of Label, Fst pairs at construction. These are
     50 // used by methods for testing cyclic dependencies and connectedness
     51 // and doing RTN connection and specific Fst replacement by label or
     52 // for various optimization properties. The modified results can be
     53 // obtained with the GetFstPairs() or GetMutableFstPairs() methods.
     54 template <class Arc>
     55 class ReplaceUtil {
     56  public:
     57   typedef typename Arc::Label Label;
     58   typedef typename Arc::Weight Weight;
     59   typedef typename Arc::StateId StateId;
     60 
     61   typedef pair<Label, const Fst<Arc>*> FstPair;
     62   typedef pair<Label, MutableFst<Arc>*> MutableFstPair;
     63   typedef unordered_map<Label, Label> NonTerminalHash;
     64 
     65   // Constructs from mutable Fsts; Fst ownership given to ReplaceUtil.
     66   ReplaceUtil(const vector<MutableFstPair> &fst_pairs,
     67               Label root_label, bool epsilon_on_replace = false);
     68 
     69   // Constructs from Fsts; Fst ownership retained by caller.
     70   ReplaceUtil(const vector<FstPair> &fst_pairs,
     71               Label root_label, bool epsilon_on_replace = false);
     72 
     73   // Constructs from ReplaceFst internals; ownership retained by caller.
     74   ReplaceUtil(const vector<const Fst<Arc> *> &fst_array,
     75               const NonTerminalHash &nonterminal_hash, Label root_fst,
     76               bool epsilon_on_replace = false);
     77 
     78   ~ReplaceUtil() {
     79     for (Label i = 0; i < fst_array_.size(); ++i)
     80       delete fst_array_[i];
     81   }
     82 
     83   // True if the non-terminal dependencies are cyclic. Cyclic
     84   // dependencies will result in an unexpandable replace fst.
     85   bool CyclicDependencies() const {
     86     GetDependencies(false);
     87     return depprops_ & kCyclic;
     88   }
     89 
     90   // Returns true if no useless Fsts, states or transitions.
     91   bool Connected() const {
     92     GetDependencies(false);
     93     uint64 props = kAccessible | kCoAccessible;
     94     for (Label i = 0; i < fst_array_.size(); ++i) {
     95       if (!fst_array_[i])
     96         continue;
     97       if (fst_array_[i]->Properties(props, true) != props || !depaccess_[i])
     98         return false;
     99     }
    100     return true;
    101   }
    102 
    103   // Removes useless Fsts, states and transitions.
    104   void Connect();
    105 
    106   // Replaces Fsts specified by labels.
    107   // Does nothing if there are cyclic dependencies.
    108   void ReplaceLabels(const vector<Label> &labels);
    109 
    110   // Replaces Fsts that have at most 'nstates' states, 'narcs' arcs and
    111   // 'nnonterm' non-terminals (updating in reverse dependency order).
    112   // Does nothing if there are cyclic dependencies.
    113   void ReplaceBySize(size_t nstates, size_t narcs, size_t nnonterms);
    114 
    115   // Replaces singleton Fsts.
    116   // Does nothing if there are cyclic dependencies.
    117   void ReplaceTrivial() { ReplaceBySize(2, 1, 1); }
    118 
    119   // Replaces non-terminals that have at most 'ninstances' instances
    120   // (updating in dependency order).
    121   // Does nothing if there are cyclic dependencies.
    122   void ReplaceByInstances(size_t ninstances);
    123 
    124   // Replaces non-terminals that have only one instance.
    125   // Does nothing if there are cyclic dependencies.
    126   void ReplaceUnique() { ReplaceByInstances(1); }
    127 
    128   // Returns Label, Fst pairs; Fst ownership retained by ReplaceUtil.
    129   void GetFstPairs(vector<FstPair> *fst_pairs);
    130 
    131   // Returns Label, MutableFst pairs; Fst ownership given to caller.
    132   void GetMutableFstPairs(vector<MutableFstPair> *mutable_fst_pairs);
    133 
    134  private:
    135   // Per Fst statistics
    136   struct ReplaceStats {
    137     StateId nstates;    // # of states
    138     StateId nfinal;     // # of final states
    139     size_t narcs;       // # of arcs
    140     Label nnonterms;    // # of non-terminals in Fst
    141     size_t nref;        // # of non-terminal instances referring to this Fst
    142 
    143     // # of times that ith Fst references this Fst
    144     map<Label, size_t> inref;
    145     // # of times that this Fst references the ith Fst
    146     map<Label, size_t> outref;
    147 
    148     ReplaceStats()
    149         : nstates(0),
    150           nfinal(0),
    151           narcs(0),
    152           nnonterms(0),
    153           nref(0) {}
    154   };
    155 
    156   // Check Mutable Fsts exist o.w. create them.
    157   void CheckMutableFsts();
    158 
    159   // Computes the dependency graph of the replace Fsts.
    160   // If 'stats' is true, dependency statistics computed as well.
    161   void GetDependencies(bool stats) const;
    162 
    163   void ClearDependencies() const {
    164     depfst_.DeleteStates();
    165     stats_.clear();
    166     depprops_ = 0;
    167     have_stats_ = false;
    168   }
    169 
    170   // Get topological order of dependencies. Returns false with cyclic input.
    171   bool GetTopOrder(const Fst<Arc> &fst, vector<Label> *toporder) const;
    172 
    173   // Update statistics assuming that jth Fst will be replaced.
    174   void UpdateStats(Label j);
    175 
    176   Label root_label_;                              // root non-terminal
    177   Label root_fst_;                                // root Fst ID
    178   bool epsilon_on_replace_;                       // see Replace()
    179   vector<const Fst<Arc> *> fst_array_;            // Fst per ID
    180   vector<MutableFst<Arc> *> mutable_fst_array_;   // MutableFst per ID
    181   vector<Label> nonterminal_array_;               // Fst ID to non-terminal
    182   NonTerminalHash nonterminal_hash_;              // non-terminal to Fst ID
    183   mutable VectorFst<Arc> depfst_;                 // Fst ID dependencies
    184   mutable vector<bool> depaccess_;                // Fst ID accessibility
    185   mutable uint64 depprops_;                       // dependency Fst props
    186   mutable bool have_stats_;                       // have dependency statistics
    187   mutable vector<ReplaceStats> stats_;            // Per Fst statistics
    188   DISALLOW_COPY_AND_ASSIGN(ReplaceUtil);
    189 };
    190 
    191 template <class Arc>
    192 ReplaceUtil<Arc>::ReplaceUtil(
    193     const vector<MutableFstPair> &fst_pairs,
    194     Label root_label, bool epsilon_on_replace)
    195     : root_label_(root_label),
    196       epsilon_on_replace_(epsilon_on_replace),
    197       depprops_(0),
    198       have_stats_(false) {
    199   fst_array_.push_back(0);
    200   mutable_fst_array_.push_back(0);
    201   nonterminal_array_.push_back(kNoLabel);
    202   for (Label i = 0; i < fst_pairs.size(); ++i) {
    203     Label label = fst_pairs[i].first;
    204     MutableFst<Arc> *fst = fst_pairs[i].second;
    205     nonterminal_hash_[label] = fst_array_.size();
    206     nonterminal_array_.push_back(label);
    207     fst_array_.push_back(fst);
    208     mutable_fst_array_.push_back(fst);
    209   }
    210   root_fst_ = nonterminal_hash_[root_label_];
    211   if (!root_fst_)
    212     FSTERROR() << "ReplaceUtil: no root FST for label: " << root_label_;
    213 }
    214 
    215 template <class Arc>
    216 ReplaceUtil<Arc>::ReplaceUtil(
    217     const vector<FstPair> &fst_pairs,
    218     Label root_label, bool epsilon_on_replace)
    219     : root_label_(root_label),
    220       epsilon_on_replace_(epsilon_on_replace),
    221       depprops_(0),
    222       have_stats_(false) {
    223   fst_array_.push_back(0);
    224   nonterminal_array_.push_back(kNoLabel);
    225   for (Label i = 0; i < fst_pairs.size(); ++i) {
    226     Label label = fst_pairs[i].first;
    227     const Fst<Arc> *fst = fst_pairs[i].second;
    228     nonterminal_hash_[label] = fst_array_.size();
    229     nonterminal_array_.push_back(label);
    230     fst_array_.push_back(fst->Copy());
    231   }
    232   root_fst_ = nonterminal_hash_[root_label];
    233   if (!root_fst_)
    234     FSTERROR() << "ReplaceUtil: no root FST for label: " << root_label_;
    235 }
    236 
    237 template <class Arc>
    238 ReplaceUtil<Arc>::ReplaceUtil(
    239     const vector<const Fst<Arc> *> &fst_array,
    240     const NonTerminalHash &nonterminal_hash, Label root_fst,
    241     bool epsilon_on_replace)
    242     : root_fst_(root_fst),
    243       epsilon_on_replace_(epsilon_on_replace),
    244       nonterminal_array_(fst_array.size()),
    245       nonterminal_hash_(nonterminal_hash),
    246       depprops_(0),
    247       have_stats_(false) {
    248   fst_array_.push_back(0);
    249   for (Label i = 1; i < fst_array.size(); ++i)
    250     fst_array_.push_back(fst_array[i]->Copy());
    251   for (typename NonTerminalHash::const_iterator it =
    252            nonterminal_hash.begin(); it != nonterminal_hash.end(); ++it)
    253     nonterminal_array_[it->second] = it->first;
    254   root_label_ = nonterminal_array_[root_fst_];
    255 }
    256 
    257 template <class Arc>
    258 void ReplaceUtil<Arc>::GetDependencies(bool stats) const {
    259   if (depfst_.NumStates() > 0) {
    260     if (stats && !have_stats_)
    261       ClearDependencies();
    262     else
    263       return;
    264   }
    265 
    266   have_stats_ = stats;
    267   if (have_stats_)
    268     stats_.reserve(fst_array_.size());
    269 
    270   for (Label i = 0; i < fst_array_.size(); ++i) {
    271     depfst_.AddState();
    272     depfst_.SetFinal(i, Weight::One());
    273     if (have_stats_)
    274       stats_.push_back(ReplaceStats());
    275   }
    276   depfst_.SetStart(root_fst_);
    277 
    278   // An arc from each state (representing the fst) to the
    279   // state representing the fst being replaced
    280   for (Label i = 0; i < fst_array_.size(); ++i) {
    281     const Fst<Arc> *ifst = fst_array_[i];
    282     if (!ifst)
    283       continue;
    284     for (StateIterator<Fst<Arc> > siter(*ifst); !siter.Done(); siter.Next()) {
    285       StateId s = siter.Value();
    286       if (have_stats_) {
    287         ++stats_[i].nstates;
    288         if (ifst->Final(s) != Weight::Zero())
    289           ++stats_[i].nfinal;
    290       }
    291       for (ArcIterator<Fst<Arc> > aiter(*ifst, s);
    292            !aiter.Done(); aiter.Next()) {
    293         if (have_stats_)
    294           ++stats_[i].narcs;
    295         const Arc& arc = aiter.Value();
    296 
    297         typename NonTerminalHash::const_iterator it =
    298             nonterminal_hash_.find(arc.olabel);
    299         if (it != nonterminal_hash_.end()) {
    300           Label j = it->second;
    301           depfst_.AddArc(i, Arc(arc.olabel, arc.olabel, Weight::One(), j));
    302           if (have_stats_) {
    303             ++stats_[i].nnonterms;
    304             ++stats_[j].nref;
    305             ++stats_[j].inref[i];
    306             ++stats_[i].outref[j];
    307           }
    308         }
    309       }
    310     }
    311   }
    312 
    313   // Gets accessibility info
    314   SccVisitor<Arc> scc_visitor(0, &depaccess_, 0, &depprops_);
    315   DfsVisit(depfst_, &scc_visitor);
    316 }
    317 
    318 template <class Arc>
    319 void ReplaceUtil<Arc>::UpdateStats(Label j) {
    320   if (!have_stats_) {
    321     FSTERROR() << "ReplaceUtil::UpdateStats: stats not available";
    322     return;
    323   }
    324 
    325   if (j == root_fst_)  // can't replace root
    326     return;
    327 
    328   typedef typename map<Label, size_t>::iterator Iter;
    329   for (Iter in = stats_[j].inref.begin();
    330        in != stats_[j].inref.end();
    331        ++in) {
    332     Label i = in->first;
    333     size_t ni = in->second;
    334     stats_[i].nstates += stats_[j].nstates * ni;
    335     stats_[i].narcs += (stats_[j].narcs + 1) * ni;  // narcs - 1 + 2 (eps)
    336     stats_[i].nnonterms += (stats_[j].nnonterms - 1) * ni;
    337     stats_[i].outref.erase(stats_[i].outref.find(j));
    338     for (Iter out = stats_[j].outref.begin();
    339          out != stats_[j].outref.end();
    340          ++out) {
    341       Label k = out->first;
    342       size_t nk = out->second;
    343       stats_[i].outref[k] += ni * nk;
    344     }
    345   }
    346 
    347   for (Iter out = stats_[j].outref.begin();
    348        out != stats_[j].outref.end();
    349        ++out) {
    350     Label k = out->first;
    351     size_t nk = out->second;
    352     stats_[k].nref -= nk;
    353     stats_[k].inref.erase(stats_[k].inref.find(j));
    354     for (Iter in = stats_[j].inref.begin();
    355          in != stats_[j].inref.end();
    356          ++in) {
    357       Label i = in->first;
    358       size_t ni = in->second;
    359       stats_[k].inref[i] += ni * nk;
    360       stats_[k].nref += ni * nk;
    361     }
    362   }
    363 }
    364 
    365 template <class Arc>
    366 void ReplaceUtil<Arc>::CheckMutableFsts() {
    367   if (mutable_fst_array_.size() == 0) {
    368     for (Label i = 0; i < fst_array_.size(); ++i) {
    369       if (!fst_array_[i]) {
    370         mutable_fst_array_.push_back(0);
    371       } else {
    372         mutable_fst_array_.push_back(new VectorFst<Arc>(*fst_array_[i]));
    373         delete fst_array_[i];
    374         fst_array_[i] = mutable_fst_array_[i];
    375       }
    376     }
    377   }
    378 }
    379 
    380 template <class Arc>
    381 void ReplaceUtil<Arc>::Connect() {
    382   CheckMutableFsts();
    383   uint64 props = kAccessible | kCoAccessible;
    384   for (Label i = 0; i < mutable_fst_array_.size(); ++i) {
    385     if (!mutable_fst_array_[i])
    386       continue;
    387     if (mutable_fst_array_[i]->Properties(props, false) != props)
    388       fst::Connect(mutable_fst_array_[i]);
    389   }
    390   GetDependencies(false);
    391   for (Label i = 0; i < mutable_fst_array_.size(); ++i) {
    392     MutableFst<Arc> *fst = mutable_fst_array_[i];
    393     if (fst && !depaccess_[i]) {
    394       delete fst;
    395       fst_array_[i] = 0;
    396       mutable_fst_array_[i] = 0;
    397     }
    398   }
    399   ClearDependencies();
    400 }
    401 
    402 template <class Arc>
    403 bool ReplaceUtil<Arc>::GetTopOrder(const Fst<Arc> &fst,
    404                                    vector<Label> *toporder) const {
    405   // Finds topological order of dependencies.
    406   vector<StateId> order;
    407   bool acyclic = false;
    408 
    409   TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic);
    410   DfsVisit(fst, &top_order_visitor);
    411   if (!acyclic) {
    412     LOG(WARNING) << "ReplaceUtil::GetTopOrder: Cyclical label dependencies";
    413     return false;
    414   }
    415 
    416   toporder->resize(order.size());
    417   for (Label i = 0; i < order.size(); ++i)
    418     (*toporder)[order[i]] = i;
    419 
    420   return true;
    421 }
    422 
    423 template <class Arc>
    424 void ReplaceUtil<Arc>::ReplaceLabels(const vector<Label> &labels) {
    425   CheckMutableFsts();
    426   unordered_set<Label> label_set;
    427   for (Label i = 0; i < labels.size(); ++i)
    428     if (labels[i] != root_label_)  // can't replace root
    429       label_set.insert(labels[i]);
    430 
    431   // Finds Fst dependencies restricted to the labels requested.
    432   GetDependencies(false);
    433   VectorFst<Arc> pfst(depfst_);
    434   for (StateId i = 0; i < pfst.NumStates(); ++i) {
    435     vector<Arc> arcs;
    436     for (ArcIterator< VectorFst<Arc> > aiter(pfst, i);
    437          !aiter.Done(); aiter.Next()) {
    438       const Arc &arc = aiter.Value();
    439       Label label = nonterminal_array_[arc.nextstate];
    440       if (label_set.count(label) > 0)
    441         arcs.push_back(arc);
    442     }
    443     pfst.DeleteArcs(i);
    444     for (size_t j = 0; j < arcs.size(); ++j)
    445       pfst.AddArc(i, arcs[j]);
    446   }
    447 
    448   vector<Label> toporder;
    449   if (!GetTopOrder(pfst, &toporder)) {
    450     ClearDependencies();
    451     return;
    452   }
    453 
    454   // Visits Fsts in reverse topological order of dependencies and
    455   // performs replacements.
    456   for (Label o = toporder.size() - 1; o >= 0;  --o) {
    457     vector<FstPair> fst_pairs;
    458     StateId s = toporder[o];
    459     for (ArcIterator< VectorFst<Arc> > aiter(pfst, s);
    460          !aiter.Done(); aiter.Next()) {
    461       const Arc &arc = aiter.Value();
    462       Label label = nonterminal_array_[arc.nextstate];
    463       const Fst<Arc> *fst = fst_array_[arc.nextstate];
    464       fst_pairs.push_back(make_pair(label, fst));
    465     }
    466     if (fst_pairs.empty())
    467         continue;
    468     Label label = nonterminal_array_[s];
    469     const Fst<Arc> *fst = fst_array_[s];
    470     fst_pairs.push_back(make_pair(label, fst));
    471 
    472     Replace(fst_pairs, mutable_fst_array_[s], label, epsilon_on_replace_);
    473   }
    474   ClearDependencies();
    475 }
    476 
    477 template <class Arc>
    478 void ReplaceUtil<Arc>::ReplaceBySize(size_t nstates, size_t narcs,
    479                                      size_t nnonterms) {
    480   vector<Label> labels;
    481   GetDependencies(true);
    482 
    483   vector<Label> toporder;
    484   if (!GetTopOrder(depfst_, &toporder)) {
    485     ClearDependencies();
    486     return;
    487   }
    488 
    489   for (Label o = toporder.size() - 1; o >= 0; --o) {
    490     Label j = toporder[o];
    491     if (stats_[j].nstates <= nstates &&
    492         stats_[j].narcs <= narcs &&
    493         stats_[j].nnonterms <= nnonterms) {
    494       labels.push_back(nonterminal_array_[j]);
    495       UpdateStats(j);
    496     }
    497   }
    498   ReplaceLabels(labels);
    499 }
    500 
    501 template <class Arc>
    502 void ReplaceUtil<Arc>::ReplaceByInstances(size_t ninstances) {
    503   vector<Label> labels;
    504   GetDependencies(true);
    505 
    506   vector<Label> toporder;
    507   if (!GetTopOrder(depfst_, &toporder)) {
    508     ClearDependencies();
    509     return;
    510   }
    511   for (Label o = 0; o < toporder.size(); ++o) {
    512     Label j = toporder[o];
    513     if (stats_[j].nref <= ninstances) {
    514       labels.push_back(nonterminal_array_[j]);
    515       UpdateStats(j);
    516     }
    517   }
    518   ReplaceLabels(labels);
    519 }
    520 
    521 template <class Arc>
    522 void ReplaceUtil<Arc>::GetFstPairs(vector<FstPair> *fst_pairs) {
    523   CheckMutableFsts();
    524   fst_pairs->clear();
    525   for (Label i = 0; i < fst_array_.size(); ++i) {
    526     Label label = nonterminal_array_[i];
    527     const Fst<Arc> *fst = fst_array_[i];
    528     if (!fst)
    529       continue;
    530     fst_pairs->push_back(make_pair(label, fst));
    531   }
    532 }
    533 
    534 template <class Arc>
    535 void ReplaceUtil<Arc>::GetMutableFstPairs(
    536     vector<MutableFstPair> *mutable_fst_pairs) {
    537   CheckMutableFsts();
    538   mutable_fst_pairs->clear();
    539   for (Label i = 0; i < mutable_fst_array_.size(); ++i) {
    540     Label label = nonterminal_array_[i];
    541     MutableFst<Arc> *fst = mutable_fst_array_[i];
    542     if (!fst)
    543       continue;
    544     mutable_fst_pairs->push_back(make_pair(label, fst->Copy()));
    545   }
    546 }
    547 
    548 }  // namespace fst
    549 
    550 #endif  // FST_LIB_REPLACE_UTIL_H__
    551