Home | History | Annotate | Download | only in fst
      1 // push.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: allauzen (at) google.com (Cyril Allauzen)
     17 //
     18 // \file
     19 // Class to reweight/push an FST.
     20 
     21 #ifndef FST_LIB_PUSH_H__
     22 #define FST_LIB_PUSH_H__
     23 
     24 #include <vector>
     25 using std::vector;
     26 
     27 #include <fst/factor-weight.h>
     28 #include <fst/fst.h>
     29 #include <fst/arc-map.h>
     30 #include <fst/reweight.h>
     31 #include <fst/shortest-distance.h>
     32 
     33 
     34 namespace fst {
     35 
     36 // Private helper functions for Push
     37 namespace internal {
     38 
     39 // Compute the total weight (sum of the weights of all accepting paths) from
     40 // the output of ShortestDistance. 'distance' is the shortest distance from the
     41 // initial state when 'reverse == false' and to the final states when
     42 // 'reverse == true'.
     43 template <class Arc>
     44 typename Arc::Weight ComputeTotalWeight(
     45     const Fst<Arc> &fst,
     46     const vector<typename Arc::Weight> &distance,
     47     bool reverse) {
     48   if (reverse)
     49     return fst.Start() < distance.size() ?
     50         distance[fst.Start()] : Arc::Weight::Zero();
     51 
     52   typename Arc::Weight sum = Arc::Weight::Zero();
     53   for (typename Arc::StateId s = 0; s < distance.size(); ++s)
     54     sum = Plus(sum, Times(distance[s], fst.Final(s)));
     55   return sum;
     56 }
     57 
     58 // Divide the weight of every accepting path by 'w'. The weight 'w' is
     59 // divided at the final states if 'at_final == true' and at the
     60 // initial state otherwise.
     61 template <class Arc>
     62 void RemoveWeight(MutableFst<Arc> *fst, typename Arc::Weight w, bool at_final) {
     63   if ((w == Arc::Weight::One()) || (w == Arc::Weight::Zero()))
     64       return;
     65 
     66   if (at_final) {
     67     // Remove 'w' from the final states
     68     for (StateIterator< MutableFst<Arc> > sit(*fst);
     69          !sit.Done();
     70          sit.Next())
     71       fst->SetFinal(sit.Value(),
     72                     Divide(fst->Final(sit.Value()), w,  DIVIDE_RIGHT));
     73   } else {  // at_final == false
     74     // Remove 'w' from the initial state
     75     typename Arc::StateId start = fst->Start();
     76     for (MutableArcIterator<MutableFst<Arc> > ait(fst, start);
     77          !ait.Done();
     78          ait.Next()) {
     79       Arc arc = ait.Value();
     80       arc.weight = Divide(arc.weight, w, DIVIDE_LEFT);
     81       ait.SetValue(arc);
     82     }
     83     fst->SetFinal(start, Divide(fst->Final(start), w, DIVIDE_LEFT));
     84   }
     85 }
     86 }  // namespace internal
     87 
     88 // Pushes the weights in FST in the direction defined by TYPE.  If
     89 // pushing towards the initial state, the sum of the weight of the
     90 // outgoing transitions and final weight at a non-initial state is
     91 // equal to One() in the resulting machine.  If pushing towards the
     92 // final state, the same property holds on the reverse machine.
     93 //
     94 // Weight needs to be left distributive when pushing towards the
     95 // initial state and right distributive when pushing towards the final
     96 // states.
     97 template <class Arc>
     98 void Push(MutableFst<Arc> *fst,
     99           ReweightType type,
    100           float delta = kDelta,
    101           bool remove_total_weight = false) {
    102   vector<typename Arc::Weight> distance;
    103   ShortestDistance(*fst, &distance, type == REWEIGHT_TO_INITIAL, delta);
    104   typename Arc::Weight total_weight = Arc::Weight::One();
    105   if (remove_total_weight)
    106     total_weight = internal::ComputeTotalWeight(*fst, distance,
    107                                                 type == REWEIGHT_TO_INITIAL);
    108   Reweight(fst, distance, type);
    109   if (remove_total_weight)
    110     internal::RemoveWeight(fst, total_weight, type == REWEIGHT_TO_FINAL);
    111 }
    112 
    113 const uint32 kPushWeights = 0x0001;
    114 const uint32 kPushLabels =  0x0002;
    115 const uint32 kPushRemoveTotalWeight = 0x0004;
    116 const uint32 kPushRemoveCommonAffix = 0x0008;
    117 
    118 // OFST obtained from IFST by pushing weights and/or labels according
    119 // to PTYPE in the direction defined by RTYPE.  Weight needs to be
    120 // left distributive when pushing weights towards the initial state
    121 // and right distributive when pushing weights towards the final
    122 // states.
    123 template <class Arc, ReweightType rtype>
    124 void Push(const Fst<Arc> &ifst,
    125           MutableFst<Arc> *ofst,
    126           uint32 ptype,
    127           float delta = kDelta) {
    128 
    129   if ((ptype & (kPushWeights | kPushLabels)) == kPushWeights) {
    130     *ofst = ifst;
    131     Push(ofst, rtype, delta, ptype & kPushRemoveTotalWeight);
    132   } else if (ptype & kPushLabels) {
    133     const StringType stype = rtype == REWEIGHT_TO_INITIAL
    134                              ? STRING_LEFT
    135                              : STRING_RIGHT;
    136     vector<typename GallicArc<Arc, stype>::Weight> gdistance;
    137     VectorFst<GallicArc<Arc, stype> > gfst;
    138     ArcMap(ifst, &gfst, ToGallicMapper<Arc, stype>());
    139     if (ptype & kPushWeights ) {
    140       ShortestDistance(gfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta);
    141     } else {
    142       ArcMapFst<Arc, Arc, RmWeightMapper<Arc> >
    143         uwfst(ifst, RmWeightMapper<Arc>());
    144       ArcMapFst<Arc, GallicArc<Arc, stype>, ToGallicMapper<Arc, stype> >
    145         guwfst(uwfst, ToGallicMapper<Arc, stype>());
    146       ShortestDistance(guwfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta);
    147     }
    148     typename GallicArc<Arc, stype>::Weight total_weight =
    149         GallicArc<Arc, stype>::Weight::One();
    150     if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) {
    151       total_weight = internal::ComputeTotalWeight(
    152           gfst, gdistance, rtype == REWEIGHT_TO_INITIAL);
    153       total_weight = typename GallicArc<Arc, stype>::Weight(
    154           ptype & kPushRemoveCommonAffix ? total_weight.Value1()
    155           : StringWeight<typename Arc::Label, stype>::One(),
    156           ptype & kPushRemoveTotalWeight ? total_weight.Value2()
    157           : Arc::Weight::One());
    158     }
    159     Reweight(&gfst, gdistance, rtype);
    160     if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix))
    161       internal::RemoveWeight(&gfst, total_weight, rtype == REWEIGHT_TO_FINAL);
    162     FactorWeightFst< GallicArc<Arc, stype>, GallicFactor<typename Arc::Label,
    163       typename Arc::Weight, stype> > fwfst(gfst);
    164     ArcMap(fwfst, ofst, FromGallicMapper<Arc, stype>());
    165     ofst->SetOutputSymbols(ifst.OutputSymbols());
    166   } else {
    167     LOG(WARNING) << "Push: pushing type is set to 0: "
    168                  << "pushing neither labels nor weights.";
    169     *ofst = ifst;
    170   }
    171 }
    172 
    173 }  // namespace fst
    174 
    175 #endif /* FST_LIB_PUSH_H__ */
    176