// reweight.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: allauzen (at) google.com (Cyril Allauzen)
//
// \file
// Function to reweight an FST.

#ifndef FST_LIB_REWEIGHT_H__
#define FST_LIB_REWEIGHT_H__

#include <vector>
// NOTE(review): this using-declaration at header scope leaks `vector` into
// the global namespace of every includer; kept because other files may
// rely on it.
using std::vector;

#include <fst/mutable-fst.h>


namespace fst {

// Direction in which Reweight() moves weight along the FST:
//   REWEIGHT_TO_INITIAL -- towards the initial state (requires a left
//                          distributive Weight);
//   REWEIGHT_TO_FINAL   -- towards the final states (requires a right
//                          distributive Weight).
enum ReweightType { REWEIGHT_TO_INITIAL, REWEIGHT_TO_FINAL };

// Reweights FST according to the potentials defined by the POTENTIAL
// vector (indexed by state id), in the direction defined by TYPE. Weight
// needs to be left distributive when reweighting towards the initial state
// and right distributive when reweighting towards the final states.
//
// An arc of weight w, with an origin state of potential p and a
// destination state of potential q, is reweighted:
//   - towards the initial state: to p\wq, i.e. the product w q divided on
//     the left by p (Divide(..., DIVIDE_LEFT));
//   - towards the final states: to pw/q, i.e. the product p w divided on
//     the right by q (Divide(..., DIVIDE_RIGHT)).
43 template <class Arc> 44 void Reweight(MutableFst<Arc> *fst, 45 const vector<typename Arc::Weight> &potential, 46 ReweightType type) { 47 typedef typename Arc::Weight Weight; 48 49 if (fst->NumStates() == 0) 50 return; 51 52 if (type == REWEIGHT_TO_FINAL && !(Weight::Properties() & kRightSemiring)) { 53 FSTERROR() << "Reweight: Reweighting to the final states requires " 54 << "Weight to be right distributive: " 55 << Weight::Type(); 56 fst->SetProperties(kError, kError); 57 return; 58 } 59 60 if (type == REWEIGHT_TO_INITIAL && !(Weight::Properties() & kLeftSemiring)) { 61 FSTERROR() << "Reweight: Reweighting to the initial state requires " 62 << "Weight to be left distributive: " 63 << Weight::Type(); 64 fst->SetProperties(kError, kError); 65 return; 66 } 67 68 StateIterator< MutableFst<Arc> > sit(*fst); 69 for (; !sit.Done(); sit.Next()) { 70 typename Arc::StateId state = sit.Value(); 71 if (state == potential.size()) 72 break; 73 typename Arc::Weight weight = potential[state]; 74 if (weight != Weight::Zero()) { 75 for (MutableArcIterator< MutableFst<Arc> > ait(fst, state); 76 !ait.Done(); 77 ait.Next()) { 78 Arc arc = ait.Value(); 79 if (arc.nextstate >= potential.size()) 80 continue; 81 typename Arc::Weight nextweight = potential[arc.nextstate]; 82 if (nextweight == Weight::Zero()) 83 continue; 84 if (type == REWEIGHT_TO_INITIAL) 85 arc.weight = Divide(Times(arc.weight, nextweight), weight, 86 DIVIDE_LEFT); 87 if (type == REWEIGHT_TO_FINAL) 88 arc.weight = Divide(Times(weight, arc.weight), nextweight, 89 DIVIDE_RIGHT); 90 ait.SetValue(arc); 91 } 92 if (type == REWEIGHT_TO_INITIAL) 93 fst->SetFinal(state, Divide(fst->Final(state), weight, DIVIDE_LEFT)); 94 } 95 if (type == REWEIGHT_TO_FINAL) 96 fst->SetFinal(state, Times(weight, fst->Final(state))); 97 } 98 99 // This handles elements past the end of the potentials array. 
100 for (; !sit.Done(); sit.Next()) { 101 typename Arc::StateId state = sit.Value(); 102 if (type == REWEIGHT_TO_FINAL) 103 fst->SetFinal(state, Times(Weight::Zero(), fst->Final(state))); 104 } 105 106 typename Arc::Weight startweight = fst->Start() < potential.size() ? 107 potential[fst->Start()] : Weight::Zero(); 108 if ((startweight != Weight::One()) && (startweight != Weight::Zero())) { 109 if (fst->Properties(kInitialAcyclic, true) & kInitialAcyclic) { 110 typename Arc::StateId state = fst->Start(); 111 for (MutableArcIterator< MutableFst<Arc> > ait(fst, state); 112 !ait.Done(); 113 ait.Next()) { 114 Arc arc = ait.Value(); 115 if (type == REWEIGHT_TO_INITIAL) 116 arc.weight = Times(startweight, arc.weight); 117 else 118 arc.weight = Times( 119 Divide(Weight::One(), startweight, DIVIDE_RIGHT), 120 arc.weight); 121 ait.SetValue(arc); 122 } 123 if (type == REWEIGHT_TO_INITIAL) 124 fst->SetFinal(state, Times(startweight, fst->Final(state))); 125 else 126 fst->SetFinal(state, Times(Divide(Weight::One(), startweight, 127 DIVIDE_RIGHT), 128 fst->Final(state))); 129 } else { 130 typename Arc::StateId state = fst->AddState(); 131 Weight w = type == REWEIGHT_TO_INITIAL ? startweight : 132 Divide(Weight::One(), startweight, DIVIDE_RIGHT); 133 Arc arc(0, 0, w, fst->Start()); 134 fst->AddArc(state, arc); 135 fst->SetStart(state); 136 } 137 } 138 139 fst->SetProperties(ReweightProperties( 140 fst->Properties(kFstProperties, false)), 141 kFstProperties); 142 } 143 144 } // namespace fst 145 146 #endif // FST_LIB_REWEIGHT_H_ 147