Home | History | Annotate | Download | only in lib
      1 // prune.h
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Author: allauzen (at) cs.nyu.edu (Cyril Allauzen)
     16 //
     17 // \file
     18 // Functions implementing pruning.
     19 
     20 #ifndef FST_LIB_PRUNE_H__
     21 #define FST_LIB_PRUNE_H__
     22 
     23 #include "fst/lib/arcfilter.h"
     24 #include "fst/lib/shortest-distance.h"
     25 
     26 namespace fst {
     27 
     28 template <class A, class ArcFilter>
     29 class PruneOptions {
     30  public:
     31   typedef typename A::Weight Weight;
     32 
     33   // Pruning threshold.
     34   Weight threshold;
     35   // Arc filter.
     36   ArcFilter filter;
     37   // If non-zero, passes in pre-computed shortest distance from initial state
     38   // (possibly resized).
     39   vector<Weight> *idistance;
     40   // If non-zero, passes in pre-computed shortest distance to final states
     41   // (possibly resized).
     42   vector<Weight> *fdistance;
     43 
     44   PruneOptions(const Weight& t, ArcFilter f, vector<Weight> *id = 0,
     45                vector<Weight> *fd = 0)
     46       : threshold(t), filter(f), idistance(id), fdistance(fd) {}
     47 };
     48 
     49 
     50 // Pruning algorithm: this version modifies its input and it takes an
     51 // options class as an argment. Delete states and arcs in 'fst' that
     52 // do not belong to a successful path whose weight is no more than
     53 // 'opts.threshold' Times() the weight of the shortest path. Weights
     54 // need to be commutative and have the path property.
     55 template <class Arc, class ArcFilter>
     56 void Prune(MutableFst<Arc> *fst,
     57            const PruneOptions<Arc, ArcFilter> &opts) {
     58   typedef typename Arc::Weight Weight;
     59   typedef typename Arc::StateId StateId;
     60 
     61   if ((Weight::Properties() & (kPath | kCommutative))
     62       != (kPath | kCommutative))
     63     LOG(FATAL) << "Prune: Weight needs to have the path property and"
     64                << " be commutative: "
     65                << Weight::Type();
     66 
     67   StateId ns = fst->NumStates();
     68   if (ns == 0) return;
     69 
     70   vector<Weight> *idistance = opts.idistance;
     71   vector<Weight> *fdistance = opts.fdistance;
     72 
     73   if (!idistance) {
     74     idistance = new vector<Weight>(ns, Weight::Zero());
     75     ShortestDistance(*fst, idistance, false);
     76   } else {
     77     idistance->resize(ns, Weight::Zero());
     78   }
     79 
     80   if (!fdistance) {
     81     fdistance = new vector<Weight>(ns, Weight::Zero());
     82     ShortestDistance(*fst, fdistance, true);
     83   } else {
     84     fdistance->resize(ns, Weight::Zero());
     85   }
     86 
     87   vector<StateId> dead;
     88   dead.push_back(fst->AddState());
     89   NaturalLess<Weight> less;
     90   Weight ceiling = Times((*fdistance)[fst->Start()], opts.threshold);
     91 
     92   for (StateId state = 0; state < ns; ++state) {
     93     if (less(ceiling, Times((*idistance)[state], (*fdistance)[state]))) {
     94       dead.push_back(state);
     95       continue;
     96     }
     97     for (MutableArcIterator< MutableFst<Arc> > it(fst, state);
     98          !it.Done();
     99          it.Next()) {
    100       Arc arc = it.Value();
    101       if (!opts.filter(arc)) continue;
    102       Weight weight = Times(Times((*idistance)[state], arc.weight),
    103                            (*fdistance)[arc.nextstate]);
    104       if(less(ceiling, weight)) {
    105         arc.nextstate = dead[0];
    106         it.SetValue(arc);
    107       }
    108     }
    109     if (less(ceiling, Times((*idistance)[state], fst->Final(state))))
    110       fst->SetFinal(state, Weight::Zero());
    111   }
    112 
    113   fst->DeleteStates(dead);
    114 
    115   if (!opts.idistance)
    116     delete idistance;
    117   if (!opts.fdistance)
    118     delete fdistance;
    119 }
    120 
    121 
    122 // Pruning algorithm: this version modifies its input and simply takes
    123 // the pruning threshold as an argument. Delete states and arcs in
    124 // 'fst' that do not belong to a successful path whose weight is no
    125 // more than 'opts.threshold' Times() the weight of the shortest
    126 // path. Weights need to be commutative and have the path property.
    127 template <class Arc>
    128 void Prune(MutableFst<Arc> *fst, typename Arc::Weight threshold) {
    129   PruneOptions<Arc, AnyArcFilter<Arc> > opts(threshold, AnyArcFilter<Arc>());
    130   Prune(fst, opts);
    131 }
    132 
    133 
    134 // Pruning algorithm: this version writes the pruned input Fst to an
    135 // output MutableFst and it takes an options class as an argument.
    136 // 'ofst' contains states and arcs that belong to a successful path in
    137 // 'ifst' whose weight is no more than 'opts.threshold' Times() the
    138 // weight of the shortest path. Weights need to be commutative and
    139 // have the path property.
    140 template <class Arc, class ArcFilter>
    141 void Prune(const Fst<Arc> &ifst,
    142            MutableFst<Arc> *ofst,
    143            const PruneOptions<Arc, ArcFilter> &opts) {
    144   typedef typename Arc::Weight Weight;
    145   typedef typename Arc::StateId StateId;
    146 
    147   if ((Weight::Properties() & (kPath | kCommutative))
    148       != (kPath | kCommutative))
    149     LOG(FATAL) << "Prune: Weight needs to have the path property and"
    150                << " be commutative: "
    151                << Weight::Type();
    152 
    153   ofst->DeleteStates();
    154 
    155   if (ifst.Start() == kNoStateId)
    156     return;
    157 
    158   vector<Weight> *idistance = opts.idistance;
    159   vector<Weight> *fdistance = opts.fdistance;
    160 
    161   if (!idistance) {
    162     idistance = new vector<Weight>;
    163     ShortestDistance(ifst, idistance, false);
    164   }
    165 
    166   if (!fdistance) {
    167     fdistance = new vector<Weight>;
    168     ShortestDistance(ifst, fdistance, true);
    169   }
    170 
    171   vector<StateId> copy;
    172   NaturalLess<Weight> less;
    173   while (fdistance->size() <= ifst.Start())
    174     fdistance->push_back(Weight::Zero());
    175   Weight ceiling = Times((*fdistance)[ifst.Start()], opts.threshold);
    176 
    177   for (StateIterator< Fst<Arc> > sit(ifst);
    178        !sit.Done();
    179        sit.Next()) {
    180     StateId state = sit.Value();
    181     while (idistance->size() <= state)
    182       idistance->push_back(Weight::Zero());
    183     while (fdistance->size() <= state)
    184       fdistance->push_back(Weight::Zero());
    185     while (copy.size() <= state)
    186       copy.push_back(kNoStateId);
    187 
    188     if (less(ceiling, Times((*idistance)[state], (*fdistance)[state])))
    189       continue;
    190 
    191     if (copy[state] == kNoStateId)
    192       copy[state] = ofst->AddState();
    193     if (!less(ceiling, Times((*idistance)[state], ifst.Final(state))))
    194       ofst->SetFinal(copy[state], ifst.Final(state));
    195 
    196     for (ArcIterator< Fst<Arc> > ait(ifst, state);
    197          !ait.Done();
    198          ait.Next()) {
    199       Arc arc = ait.Value();
    200 
    201       if (!opts.filter(arc)) continue;
    202 
    203       while (idistance->size() <= arc.nextstate)
    204         idistance->push_back(Weight::Zero());
    205       while (fdistance->size() <= arc.nextstate)
    206         fdistance->push_back(Weight::Zero());
    207       while (copy.size() <= arc.nextstate)
    208         copy.push_back(kNoStateId);
    209 
    210       Weight weight = Times(Times((*idistance)[state], arc.weight),
    211                            (*fdistance)[arc.nextstate]);
    212 
    213       if (!less(ceiling, weight)) {
    214         if (copy[arc.nextstate] == kNoStateId)
    215           copy[arc.nextstate] = ofst->AddState();
    216         arc.nextstate = copy[arc.nextstate];
    217         ofst->AddArc(copy[state], arc);
    218       }
    219     }
    220   }
    221 
    222   ofst->SetStart(copy[ifst.Start()]);
    223 
    224   if (!opts.idistance)
    225     delete idistance;
    226   if (!opts.fdistance)
    227     delete fdistance;
    228 }
    229 
    230 
    231 // Pruning algorithm: this version writes the pruned input Fst to an
    232 // output MutableFst and simply takes the pruning threshold as an
    233 // argument.  'ofst' contains states and arcs that belong to a
    234 // successful path in 'ifst' whose weight is no more than
    235 // 'opts.threshold' Times() the weight of the shortest path. Weights
    236 // need to be commutative and have the path property.
    237 template <class Arc>
    238 void Prune(const Fst<Arc> &ifst,
    239            MutableFst<Arc> *ofst,
    240            typename Arc::Weight threshold) {
    241   PruneOptions<Arc, AnyArcFilter<Arc> > opts(threshold, AnyArcFilter<Arc>());
    242   Prune(ifst, ofst, opts);
    243 }
    244 
    245 } // namespace fst
    246 
#endif  // FST_LIB_PRUNE_H__
    248