1 // reweight.h 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Author: allauzen (at) cs.nyu.edu (Cyril Allauzen) 16 // 17 // \file 18 // Function to reweight an FST. 19 20 #ifndef FST_LIB_REWEIGHT_H__ 21 #define FST_LIB_REWEIGHT_H__ 22 23 #include "fst/lib/mutable-fst.h" 24 25 namespace fst { 26 27 enum ReweightType { REWEIGHT_TO_INITIAL, REWEIGHT_TO_FINAL }; 28 29 // Reweight FST according to the potentials defined by the POTENTIAL 30 // vector in the direction defined by TYPE. Weight needs to be left 31 // distributive when reweighting towards the initial state and right 32 // distributive when reweighting towards the final states. 33 // 34 // An arc of weight w, with an origin state of potential p and 35 // destination state of potential q, is reweighted by p\wq when 36 // reweighting towards the initial state and by pw/q when reweighting 37 // towards the final states. 38 template <class Arc> 39 void Reweight(MutableFst<Arc> *fst, vector<typename Arc::Weight> potential, 40 ReweightType type) { 41 typedef typename Arc::Weight Weight; 42 43 if (!fst->NumStates()) 44 return; 45 while ( (int64)potential.size() < (int64)fst->NumStates()) 46 potential.push_back(Weight::Zero()); 47 48 if (type == REWEIGHT_TO_FINAL && !(Weight::Properties() & kRightSemiring)) 49 LOG(FATAL) << "Reweight: Reweighting to the final states requires " 50 << "Weight to be right distributive: " 51 << Weight::Type(); 52 53 if (type == REWEIGHT_TO_INITIAL && !(Weight::Properties() & kLeftSemiring)) 54 LOG(FATAL) << "Reweight: Reweighting to the initial state requires " 55 << "Weight to be left distributive: " 56 << Weight::Type(); 57 58 for (StateIterator< MutableFst<Arc> > sit(*fst); 59 !sit.Done(); 60 sit.Next()) { 61 typename Arc::StateId state = sit.Value(); 62 for (MutableArcIterator< MutableFst<Arc> > ait(fst, state); 63 !ait.Done(); 64 ait.Next()) { 65 Arc arc = ait.Value(); 66 if ((potential[state] == Weight::Zero()) || 67 (potential[arc.nextstate] == Weight::Zero())) 68 continue; //temp fix: needs to find best solution for zeros 69 if ((type == REWEIGHT_TO_INITIAL) 70 && (potential[state] != Weight::Zero())) 71 arc.weight = Divide(Times(arc.weight, potential[arc.nextstate]), 72 potential[state], DIVIDE_LEFT); 73 else if ((type == REWEIGHT_TO_FINAL) 74 && (potential[arc.nextstate] != Weight::Zero())) 75 arc.weight = Divide(Times(potential[state], arc.weight), 76 potential[arc.nextstate], DIVIDE_RIGHT); 77 ait.SetValue(arc); 78 } 79 if ((type == REWEIGHT_TO_INITIAL) 80 && (potential[state] != Weight::Zero())) 81 fst->SetFinal(state, 82 Divide(fst->Final(state), potential[state], DIVIDE_LEFT)); 83 else if (type == REWEIGHT_TO_FINAL) 84 fst->SetFinal(state, Times(potential[state], fst->Final(state))); 85 } 86 87 if ((potential[fst->Start()] != Weight::One()) && 88 (potential[fst->Start()] != Weight::Zero())) { 89 if (fst->Properties(kInitialAcyclic, true) & kInitialAcyclic) { 90 typename Arc::StateId state = fst->Start(); 91 for (MutableArcIterator< MutableFst<Arc> > ait(fst, state); 92 !ait.Done(); 93 ait.Next()) { 94 Arc arc = ait.Value(); 95 if (type == REWEIGHT_TO_INITIAL) 96 arc.weight = Times(potential[state], arc.weight); 97 else 98 arc.weight = Times( 99 Divide(Weight::One(), potential[state], DIVIDE_RIGHT), 100 arc.weight); 101 ait.SetValue(arc); 102 } 103 if (type == REWEIGHT_TO_INITIAL) 104 fst->SetFinal(state, Times(potential[state], fst->Final(state))); 105 else 106 fst->SetFinal(state, Times(Divide(Weight::One(), potential[state], 107 DIVIDE_RIGHT), 108 fst->Final(state))); 109 } 110 else { 111 typename Arc::StateId state = fst->AddState(); 112 Weight w = type == REWEIGHT_TO_INITIAL ? 113 potential[fst->Start()] : 114 Divide(Weight::One(), potential[fst->Start()], DIVIDE_RIGHT); 115 Arc arc (0, 0, w, fst->Start()); 116 fst->AddArc(state, arc); 117 fst->SetStart(state); 118 } 119 } 120 121 fst->SetProperties(ReweightProperties( 122 fst->Properties(kFstProperties, false)), 123 kFstProperties); 124 } 125 126 } // namespace fst 127 128 #endif /* FST_LIB_REWEIGHT_H_ */ 129