Home | History | Annotate | Download | only in lib
      1 // reweight.h
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Author: allauzen (at) cs.nyu.edu (Cyril Allauzen)
     16 //
     17 // \file
     18 // Function to reweight an FST.
     19 
     20 #ifndef FST_LIB_REWEIGHT_H__
     21 #define FST_LIB_REWEIGHT_H__
     22 
     23 #include "fst/lib/mutable-fst.h"
     24 
     25 namespace fst {
     26 
     27 enum ReweightType { REWEIGHT_TO_INITIAL, REWEIGHT_TO_FINAL };
     28 
     29 // Reweight FST according to the potentials defined by the POTENTIAL
     30 // vector in the direction defined by TYPE. Weight needs to be left
     31 // distributive when reweighting towards the initial state and right
     32 // distributive when reweighting towards the final states.
     33 //
     34 // An arc of weight w, with an origin state of potential p and
     35 // destination state of potential q, is reweighted by p\wq when
     36 // reweighting towards the initial state and by pw/q when reweighting
     37 // towards the final states.
     38 template <class Arc>
     39 void Reweight(MutableFst<Arc> *fst, vector<typename Arc::Weight> potential,
     40               ReweightType type) {
     41   typedef typename Arc::Weight Weight;
     42 
     43   if (!fst->NumStates())
     44     return;
     45   while ( (int64)potential.size() < (int64)fst->NumStates())
     46     potential.push_back(Weight::Zero());
     47 
     48   if (type == REWEIGHT_TO_FINAL && !(Weight::Properties() & kRightSemiring))
     49     LOG(FATAL) << "Reweight: Reweighting to the final states requires "
     50                << "Weight to be right distributive: "
     51                << Weight::Type();
     52 
     53   if (type == REWEIGHT_TO_INITIAL && !(Weight::Properties() & kLeftSemiring))
     54     LOG(FATAL) << "Reweight: Reweighting to the initial state requires "
     55                << "Weight to be left distributive: "
     56                << Weight::Type();
     57 
     58   for (StateIterator< MutableFst<Arc> > sit(*fst);
     59        !sit.Done();
     60        sit.Next()) {
     61     typename Arc::StateId state = sit.Value();
     62     for (MutableArcIterator< MutableFst<Arc> > ait(fst, state);
     63          !ait.Done();
     64          ait.Next()) {
     65       Arc arc = ait.Value();
     66       if ((potential[state] == Weight::Zero()) ||
     67 	  (potential[arc.nextstate] == Weight::Zero()))
     68 	continue; //temp fix: needs to find best solution for zeros
     69       if ((type == REWEIGHT_TO_INITIAL)
     70 	  && (potential[state] != Weight::Zero()))
     71         arc.weight = Divide(Times(arc.weight, potential[arc.nextstate]),
     72 			    potential[state], DIVIDE_LEFT);
     73       else if ((type == REWEIGHT_TO_FINAL)
     74 	       && (potential[arc.nextstate] != Weight::Zero()))
     75         arc.weight = Divide(Times(potential[state], arc.weight),
     76                             potential[arc.nextstate], DIVIDE_RIGHT);
     77       ait.SetValue(arc);
     78     }
     79     if ((type == REWEIGHT_TO_INITIAL)
     80 	&& (potential[state] != Weight::Zero()))
     81       fst->SetFinal(state,
     82                     Divide(fst->Final(state), potential[state], DIVIDE_LEFT));
     83     else if (type == REWEIGHT_TO_FINAL)
     84       fst->SetFinal(state, Times(potential[state], fst->Final(state)));
     85   }
     86 
     87   if ((potential[fst->Start()] != Weight::One()) &&
     88       (potential[fst->Start()] != Weight::Zero())) {
     89     if (fst->Properties(kInitialAcyclic, true) & kInitialAcyclic) {
     90       typename Arc::StateId state = fst->Start();
     91       for (MutableArcIterator< MutableFst<Arc> > ait(fst, state);
     92            !ait.Done();
     93            ait.Next()) {
     94         Arc arc = ait.Value();
     95         if (type == REWEIGHT_TO_INITIAL)
     96           arc.weight = Times(potential[state], arc.weight);
     97         else
     98           arc.weight = Times(
     99               Divide(Weight::One(), potential[state], DIVIDE_RIGHT),
    100               arc.weight);
    101         ait.SetValue(arc);
    102       }
    103       if (type == REWEIGHT_TO_INITIAL)
    104         fst->SetFinal(state, Times(potential[state], fst->Final(state)));
    105       else
    106         fst->SetFinal(state, Times(Divide(Weight::One(), potential[state],
    107                                           DIVIDE_RIGHT),
    108                                    fst->Final(state)));
    109     }
    110     else {
    111       typename Arc::StateId state = fst->AddState();
    112       Weight w = type == REWEIGHT_TO_INITIAL ?
    113                  potential[fst->Start()] :
    114                  Divide(Weight::One(), potential[fst->Start()], DIVIDE_RIGHT);
    115       Arc arc (0, 0, w, fst->Start());
    116       fst->AddArc(state, arc);
    117       fst->SetStart(state);
    118     }
    119   }
    120 
    121   fst->SetProperties(ReweightProperties(
    122                          fst->Properties(kFstProperties, false)),
    123                      kFstProperties);
    124 }
    125 
    126 }  // namespace fst
    127 
    128 #endif /* FST_LIB_REWEIGHT_H_ */
    129