1 // push.h 2 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Copyright 2005-2010 Google, Inc. 16 // Author: allauzen (at) google.com (Cyril Allauzen) 17 // 18 // \file 19 // Class to reweight/push an FST. 20 21 #ifndef FST_LIB_PUSH_H__ 22 #define FST_LIB_PUSH_H__ 23 24 #include <vector> 25 using std::vector; 26 27 #include <fst/factor-weight.h> 28 #include <fst/fst.h> 29 #include <fst/arc-map.h> 30 #include <fst/reweight.h> 31 #include <fst/shortest-distance.h> 32 33 34 namespace fst { 35 36 // Private helper functions for Push 37 namespace internal { 38 39 // Compute the total weight (sum of the weights of all accepting paths) from 40 // the output of ShortestDistance. 'distance' is the shortest distance from the 41 // initial state when 'reverse == false' and to the final states when 42 // 'reverse == true'. 43 template <class Arc> 44 typename Arc::Weight ComputeTotalWeight( 45 const Fst<Arc> &fst, 46 const vector<typename Arc::Weight> &distance, 47 bool reverse) { 48 if (reverse) 49 return fst.Start() < distance.size() ? 50 distance[fst.Start()] : Arc::Weight::Zero(); 51 52 typename Arc::Weight sum = Arc::Weight::Zero(); 53 for (typename Arc::StateId s = 0; s < distance.size(); ++s) 54 sum = Plus(sum, Times(distance[s], fst.Final(s))); 55 return sum; 56 } 57 58 // Divide the weight of every accepting path by 'w'. The weight 'w' is 59 // divided at the final states if 'at_final == true' and at the 60 // initial state otherwise. 61 template <class Arc> 62 void RemoveWeight(MutableFst<Arc> *fst, typename Arc::Weight w, bool at_final) { 63 if ((w == Arc::Weight::One()) || (w == Arc::Weight::Zero())) 64 return; 65 66 if (at_final) { 67 // Remove 'w' from the final states 68 for (StateIterator< MutableFst<Arc> > sit(*fst); 69 !sit.Done(); 70 sit.Next()) 71 fst->SetFinal(sit.Value(), 72 Divide(fst->Final(sit.Value()), w, DIVIDE_RIGHT)); 73 } else { // at_final == false 74 // Remove 'w' from the initial state 75 typename Arc::StateId start = fst->Start(); 76 for (MutableArcIterator<MutableFst<Arc> > ait(fst, start); 77 !ait.Done(); 78 ait.Next()) { 79 Arc arc = ait.Value(); 80 arc.weight = Divide(arc.weight, w, DIVIDE_LEFT); 81 ait.SetValue(arc); 82 } 83 fst->SetFinal(start, Divide(fst->Final(start), w, DIVIDE_LEFT)); 84 } 85 } 86 } // namespace internal 87 88 // Pushes the weights in FST in the direction defined by TYPE. If 89 // pushing towards the initial state, the sum of the weight of the 90 // outgoing transitions and final weight at a non-initial state is 91 // equal to One() in the resulting machine. If pushing towards the 92 // final state, the same property holds on the reverse machine. 93 // 94 // Weight needs to be left distributive when pushing towards the 95 // initial state and right distributive when pushing towards the final 96 // states. 97 template <class Arc> 98 void Push(MutableFst<Arc> *fst, 99 ReweightType type, 100 float delta = kDelta, 101 bool remove_total_weight = false) { 102 vector<typename Arc::Weight> distance; 103 ShortestDistance(*fst, &distance, type == REWEIGHT_TO_INITIAL, delta); 104 typename Arc::Weight total_weight = Arc::Weight::One(); 105 if (remove_total_weight) 106 total_weight = internal::ComputeTotalWeight(*fst, distance, 107 type == REWEIGHT_TO_INITIAL); 108 Reweight(fst, distance, type); 109 if (remove_total_weight) 110 internal::RemoveWeight(fst, total_weight, type == REWEIGHT_TO_FINAL); 111 } 112 113 const uint32 kPushWeights = 0x0001; 114 const uint32 kPushLabels = 0x0002; 115 const uint32 kPushRemoveTotalWeight = 0x0004; 116 const uint32 kPushRemoveCommonAffix = 0x0008; 117 118 // OFST obtained from IFST by pushing weights and/or labels according 119 // to PTYPE in the direction defined by RTYPE. Weight needs to be 120 // left distributive when pushing weights towards the initial state 121 // and right distributive when pushing weights towards the final 122 // states. 123 template <class Arc, ReweightType rtype> 124 void Push(const Fst<Arc> &ifst, 125 MutableFst<Arc> *ofst, 126 uint32 ptype, 127 float delta = kDelta) { 128 129 if ((ptype & (kPushWeights | kPushLabels)) == kPushWeights) { 130 *ofst = ifst; 131 Push(ofst, rtype, delta, ptype & kPushRemoveTotalWeight); 132 } else if (ptype & kPushLabels) { 133 const StringType stype = rtype == REWEIGHT_TO_INITIAL 134 ? STRING_LEFT 135 : STRING_RIGHT; 136 vector<typename GallicArc<Arc, stype>::Weight> gdistance; 137 VectorFst<GallicArc<Arc, stype> > gfst; 138 ArcMap(ifst, &gfst, ToGallicMapper<Arc, stype>()); 139 if (ptype & kPushWeights ) { 140 ShortestDistance(gfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta); 141 } else { 142 ArcMapFst<Arc, Arc, RmWeightMapper<Arc> > 143 uwfst(ifst, RmWeightMapper<Arc>()); 144 ArcMapFst<Arc, GallicArc<Arc, stype>, ToGallicMapper<Arc, stype> > 145 guwfst(uwfst, ToGallicMapper<Arc, stype>()); 146 ShortestDistance(guwfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta); 147 } 148 typename GallicArc<Arc, stype>::Weight total_weight = 149 GallicArc<Arc, stype>::Weight::One(); 150 if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) { 151 total_weight = internal::ComputeTotalWeight( 152 gfst, gdistance, rtype == REWEIGHT_TO_INITIAL); 153 total_weight = typename GallicArc<Arc, stype>::Weight( 154 ptype & kPushRemoveCommonAffix ? total_weight.Value1() 155 : StringWeight<typename Arc::Label, stype>::One(), 156 ptype & kPushRemoveTotalWeight ? total_weight.Value2() 157 : Arc::Weight::One()); 158 } 159 Reweight(&gfst, gdistance, rtype); 160 if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) 161 internal::RemoveWeight(&gfst, total_weight, rtype == REWEIGHT_TO_FINAL); 162 FactorWeightFst< GallicArc<Arc, stype>, GallicFactor<typename Arc::Label, 163 typename Arc::Weight, stype> > fwfst(gfst); 164 ArcMap(fwfst, ofst, FromGallicMapper<Arc, stype>()); 165 ofst->SetOutputSymbols(ifst.OutputSymbols()); 166 } else { 167 LOG(WARNING) << "Push: pushing type is set to 0: " 168 << "pushing neither labels nor weights."; 169 *ofst = ifst; 170 } 171 } 172 173 } // namespace fst 174 175 #endif /* FST_LIB_PUSH_H_ */ 176