Home | History | Annotate | Download | only in fst
      1 // prune.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: allauzen (at) google.com (Cyril Allauzen)
     17 //
     18 // \file
     19 // Functions implementing pruning.
     20 
     21 #ifndef FST_LIB_PRUNE_H__
     22 #define FST_LIB_PRUNE_H__
     23 
     24 #include <vector>
     25 using std::vector;
     26 
     27 #include <fst/arcfilter.h>
     28 #include <fst/heap.h>
     29 #include <fst/shortest-distance.h>
     30 
     31 
     32 namespace fst {
     33 
     34 template <class A, class ArcFilter>
     35 class PruneOptions {
     36  public:
     37   typedef typename A::Weight Weight;
     38   typedef typename A::StateId StateId;
     39 
     40   // Pruning weight threshold.
     41   Weight weight_threshold;
     42   // Pruning state threshold.
     43   StateId state_threshold;
     44   // Arc filter.
     45   ArcFilter filter;
     46   // If non-zero, passes in pre-computed shortest distance to final states.
     47   const vector<Weight> *distance;
     48   // Determines the degree of convergence required when computing shortest
     49   // distances.
     50   float delta;
     51 
     52   explicit PruneOptions(const Weight& w, StateId s, ArcFilter f,
     53                         vector<Weight> *d = 0, float e = kDelta)
     54       : weight_threshold(w),
     55         state_threshold(s),
     56         filter(f),
     57         distance(d),
     58         delta(e) {}
     59  private:
     60   PruneOptions();  // disallow
     61 };
     62 
     63 
     64 template <class S, class W>
     65 class PruneCompare {
     66  public:
     67   typedef S StateId;
     68   typedef W Weight;
     69 
     70   PruneCompare(const vector<Weight> &idistance,
     71                const vector<Weight> &fdistance)
     72       : idistance_(idistance), fdistance_(fdistance) {}
     73 
     74   bool operator()(const StateId x, const StateId y) const {
     75     Weight wx = Times(x < idistance_.size() ? idistance_[x] : Weight::Zero(),
     76                       x < fdistance_.size() ? fdistance_[x] : Weight::Zero());
     77     Weight wy = Times(y < idistance_.size() ? idistance_[y] : Weight::Zero(),
     78                       y < fdistance_.size() ? fdistance_[y] : Weight::Zero());
     79     return less_(wx, wy);
     80   }
     81 
     82  private:
     83   const vector<Weight> &idistance_;
     84   const vector<Weight> &fdistance_;
     85   NaturalLess<Weight> less_;
     86 };
     87 
     88 
     89 
     90 // Pruning algorithm: this version modifies its input and it takes an
     91 // options class as an argment. Delete states and arcs in 'fst' that
     92 // do not belong to a successful path whose weight is no more than
     93 // the weight of the shortest path Times() 'opts.weight_threshold'.
     94 // When 'opts.state_threshold != kNoStateId', the resulting transducer
     95 // will restricted further to have at most 'opts.state_threshold'
     96 // states. Weights need to be commutative and have the path
     97 // property. The weight 'w' of any cycle needs to be bounded, i.e.,
     98 // 'Plus(w, W::One()) = One()'.
     99 template <class Arc, class ArcFilter>
    100 void Prune(MutableFst<Arc> *fst,
    101            const PruneOptions<Arc, ArcFilter> &opts) {
    102   typedef typename Arc::Weight Weight;
    103   typedef typename Arc::StateId StateId;
    104 
    105   if ((Weight::Properties() & (kPath | kCommutative))
    106       != (kPath | kCommutative)) {
    107     FSTERROR() << "Prune: Weight needs to have the path property and"
    108                << " be commutative: "
    109                << Weight::Type();
    110     fst->SetProperties(kError, kError);
    111     return;
    112   }
    113   StateId ns = fst->NumStates();
    114   if (ns == 0) return;
    115   vector<Weight> idistance(ns, Weight::Zero());
    116   vector<Weight> tmp;
    117   if (!opts.distance) {
    118     tmp.reserve(ns);
    119     ShortestDistance(*fst, &tmp, true, opts.delta);
    120   }
    121   const vector<Weight> *fdistance = opts.distance ? opts.distance : &tmp;
    122 
    123   if ((opts.state_threshold == 0) ||
    124       (fdistance->size() <= fst->Start()) ||
    125       ((*fdistance)[fst->Start()] == Weight::Zero())) {
    126     fst->DeleteStates();
    127     return;
    128   }
    129   PruneCompare<StateId, Weight> compare(idistance, *fdistance);
    130   Heap< StateId, PruneCompare<StateId, Weight>, false> heap(compare);
    131   vector<bool> visited(ns, false);
    132   vector<size_t> enqueued(ns, kNoKey);
    133   vector<StateId> dead;
    134   dead.push_back(fst->AddState());
    135   NaturalLess<Weight> less;
    136   Weight limit = Times((*fdistance)[fst->Start()], opts.weight_threshold);
    137 
    138   StateId num_visited = 0;
    139   StateId s = fst->Start();
    140   if (!less(limit, (*fdistance)[s])) {
    141     idistance[s] = Weight::One();
    142     enqueued[s] = heap.Insert(s);
    143     ++num_visited;
    144   }
    145 
    146   while (!heap.Empty()) {
    147     s = heap.Top();
    148     heap.Pop();
    149     enqueued[s] = kNoKey;
    150     visited[s] = true;
    151     if (less(limit, Times(idistance[s], fst->Final(s))))
    152       fst->SetFinal(s, Weight::Zero());
    153     for (MutableArcIterator< MutableFst<Arc> > ait(fst, s);
    154          !ait.Done();
    155          ait.Next()) {
    156       Arc arc = ait.Value();
    157       if (!opts.filter(arc)) continue;
    158       Weight weight = Times(Times(idistance[s], arc.weight),
    159                             arc.nextstate < fdistance->size()
    160                             ? (*fdistance)[arc.nextstate]
    161                             : Weight::Zero());
    162       if (less(limit, weight)) {
    163         arc.nextstate = dead[0];
    164         ait.SetValue(arc);
    165         continue;
    166       }
    167       if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate]))
    168         idistance[arc.nextstate] = Times(idistance[s], arc.weight);
    169       if (visited[arc.nextstate]) continue;
    170       if ((opts.state_threshold != kNoStateId) &&
    171           (num_visited >= opts.state_threshold))
    172         continue;
    173       if (enqueued[arc.nextstate] == kNoKey) {
    174         enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
    175         ++num_visited;
    176       } else {
    177         heap.Update(enqueued[arc.nextstate], arc.nextstate);
    178       }
    179     }
    180   }
    181   for (size_t i = 0; i < visited.size(); ++i)
    182     if (!visited[i]) dead.push_back(i);
    183   fst->DeleteStates(dead);
    184 }
    185 
    186 
    187 // Pruning algorithm: this version modifies its input and simply takes
    188 // the pruning threshold as an argument. Delete states and arcs in
    189 // 'fst' that do not belong to a successful path whose weight is no
    190 // more than the weight of the shortest path Times()
    191 // 'weight_threshold'.  When 'state_threshold != kNoStateId', the
    192 // resulting transducer will be restricted further to have at most
    193 // 'opts.state_threshold' states. Weights need to be commutative and
    194 // have the path property. The weight 'w' of any cycle needs to be
    195 // bounded, i.e., 'Plus(w, W::One()) = One()'.
    196 template <class Arc>
    197 void Prune(MutableFst<Arc> *fst,
    198            typename Arc::Weight weight_threshold,
    199            typename Arc::StateId state_threshold = kNoStateId,
    200            double delta = kDelta) {
    201   PruneOptions<Arc, AnyArcFilter<Arc> > opts(weight_threshold, state_threshold,
    202                                              AnyArcFilter<Arc>(), 0, delta);
    203   Prune(fst, opts);
    204 }
    205 
    206 
    207 // Pruning algorithm: this version writes the pruned input Fst to an
    208 // output MutableFst and it takes an options class as an argument.
    209 // 'ofst' contains states and arcs that belong to a successful path in
    210 // 'ifst' whose weight is no more than the weight of the shortest path
    211 // Times() 'opts.weight_threshold'. When 'opts.state_threshold !=
    212 // kNoStateId', 'ofst' will be restricted further to have at most
    213 // 'opts.state_threshold' states. Weights need to be commutative and
    214 // have the path property. The weight 'w' of any cycle needs to be
    215 // bounded, i.e., 'Plus(w, W::One()) = One()'.
    216 template <class Arc, class ArcFilter>
    217 void Prune(const Fst<Arc> &ifst,
    218            MutableFst<Arc> *ofst,
    219            const PruneOptions<Arc, ArcFilter> &opts) {
    220   typedef typename Arc::Weight Weight;
    221   typedef typename Arc::StateId StateId;
    222 
    223   if ((Weight::Properties() & (kPath | kCommutative))
    224       != (kPath | kCommutative)) {
    225     FSTERROR() << "Prune: Weight needs to have the path property and"
    226                << " be commutative: "
    227                << Weight::Type();
    228     ofst->SetProperties(kError, kError);
    229     return;
    230   }
    231   ofst->DeleteStates();
    232   ofst->SetInputSymbols(ifst.InputSymbols());
    233   ofst->SetOutputSymbols(ifst.OutputSymbols());
    234   if (ifst.Start() == kNoStateId)
    235     return;
    236   NaturalLess<Weight> less;
    237   if (less(opts.weight_threshold, Weight::One()) ||
    238       (opts.state_threshold == 0))
    239     return;
    240   vector<Weight> idistance;
    241   vector<Weight> tmp;
    242   if (!opts.distance)
    243     ShortestDistance(ifst, &tmp, true, opts.delta);
    244   const vector<Weight> *fdistance = opts.distance ? opts.distance : &tmp;
    245 
    246   if ((fdistance->size() <= ifst.Start()) ||
    247       ((*fdistance)[ifst.Start()] == Weight::Zero())) {
    248     return;
    249   }
    250   PruneCompare<StateId, Weight> compare(idistance, *fdistance);
    251   Heap< StateId, PruneCompare<StateId, Weight>, false> heap(compare);
    252   vector<StateId> copy;
    253   vector<size_t> enqueued;
    254   vector<bool> visited;
    255 
    256   StateId s = ifst.Start();
    257   Weight limit = Times(s < fdistance->size() ? (*fdistance)[s] : Weight::Zero(),
    258                          opts.weight_threshold);
    259   while (copy.size() <= s)
    260     copy.push_back(kNoStateId);
    261   copy[s] = ofst->AddState();
    262   ofst->SetStart(copy[s]);
    263   while (idistance.size() <= s)
    264     idistance.push_back(Weight::Zero());
    265   idistance[s] = Weight::One();
    266   while (enqueued.size() <= s) {
    267     enqueued.push_back(kNoKey);
    268     visited.push_back(false);
    269   }
    270   enqueued[s] = heap.Insert(s);
    271 
    272   while (!heap.Empty()) {
    273     s = heap.Top();
    274     heap.Pop();
    275     enqueued[s] = kNoKey;
    276     visited[s] = true;
    277     if (!less(limit, Times(idistance[s], ifst.Final(s))))
    278       ofst->SetFinal(copy[s], ifst.Final(s));
    279     for (ArcIterator< Fst<Arc> > ait(ifst, s);
    280          !ait.Done();
    281          ait.Next()) {
    282       const Arc &arc = ait.Value();
    283       if (!opts.filter(arc)) continue;
    284       Weight weight = Times(Times(idistance[s], arc.weight),
    285                             arc.nextstate < fdistance->size()
    286                             ? (*fdistance)[arc.nextstate]
    287                             : Weight::Zero());
    288       if (less(limit, weight)) continue;
    289       if ((opts.state_threshold != kNoStateId) &&
    290           (ofst->NumStates() >= opts.state_threshold))
    291         continue;
    292       while (idistance.size() <= arc.nextstate)
    293         idistance.push_back(Weight::Zero());
    294       if (less(Times(idistance[s], arc.weight),
    295                idistance[arc.nextstate]))
    296         idistance[arc.nextstate] = Times(idistance[s], arc.weight);
    297       while (copy.size() <= arc.nextstate)
    298         copy.push_back(kNoStateId);
    299       if (copy[arc.nextstate] == kNoStateId)
    300         copy[arc.nextstate] = ofst->AddState();
    301       ofst->AddArc(copy[s], Arc(arc.ilabel, arc.olabel, arc.weight,
    302                                 copy[arc.nextstate]));
    303       while (enqueued.size() <= arc.nextstate) {
    304         enqueued.push_back(kNoKey);
    305         visited.push_back(false);
    306       }
    307       if (visited[arc.nextstate]) continue;
    308       if (enqueued[arc.nextstate] == kNoKey)
    309         enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
    310       else
    311         heap.Update(enqueued[arc.nextstate], arc.nextstate);
    312     }
    313   }
    314 }
    315 
    316 
    317 // Pruning algorithm: this version writes the pruned input Fst to an
    318 // output MutableFst and simply takes the pruning threshold as an
    319 // argument.  'ofst' contains states and arcs that belong to a
    320 // successful path in 'ifst' whose weight is no more than
    321 // the weight of the shortest path Times() 'weight_threshold'. When
    322 // 'state_threshold != kNoStateId', 'ofst' will be restricted further
    323 // to have at most 'opts.state_threshold' states. Weights need to be
    324 // commutative and have the path property. The weight 'w' of any cycle
    325 // needs to be bounded, i.e., 'Plus(w, W::One()) = W::One()'.
    326 template <class Arc>
    327 void Prune(const Fst<Arc> &ifst,
    328            MutableFst<Arc> *ofst,
    329            typename Arc::Weight weight_threshold,
    330            typename Arc::StateId state_threshold = kNoStateId,
    331            float delta = kDelta) {
    332   PruneOptions<Arc, AnyArcFilter<Arc> > opts(weight_threshold, state_threshold,
    333                                              AnyArcFilter<Arc>(), 0, delta);
    334   Prune(ifst, ofst, opts);
    335 }
    336 
    337 }  // namespace fst
    338 
    339 #endif // FST_LIB_PRUNE_H__
    340