Home | History | Annotate | Download | only in fst
      1 // reweight.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: allauzen (at) google.com (Cyril Allauzen)
     17 //
     18 // \file
     19 // Function to reweight an FST.
     20 
     21 #ifndef FST_LIB_REWEIGHT_H__
     22 #define FST_LIB_REWEIGHT_H__
     23 
     24 #include <vector>
     25 using std::vector;
     26 
     27 #include <fst/mutable-fst.h>
     28 
     29 
     30 namespace fst {
     31 
     32 enum ReweightType { REWEIGHT_TO_INITIAL, REWEIGHT_TO_FINAL };
     33 
     34 // Reweight FST according to the potentials defined by the POTENTIAL
     35 // vector in the direction defined by TYPE. Weight needs to be left
     36 // distributive when reweighting towards the initial state and right
     37 // distributive when reweighting towards the final states.
     38 //
     39 // An arc of weight w, with an origin state of potential p and
     40 // destination state of potential q, is reweighted by p\wq when
     41 // reweighting towards the initial state and by pw/q when reweighting
     42 // towards the final states.
     43 template <class Arc>
     44 void Reweight(MutableFst<Arc> *fst,
     45               const vector<typename Arc::Weight> &potential,
     46               ReweightType type) {
     47   typedef typename Arc::Weight Weight;
     48 
     49   if (fst->NumStates() == 0)
     50     return;
     51 
     52   if (type == REWEIGHT_TO_FINAL && !(Weight::Properties() & kRightSemiring)) {
     53     FSTERROR() << "Reweight: Reweighting to the final states requires "
     54                << "Weight to be right distributive: "
     55                << Weight::Type();
     56     fst->SetProperties(kError, kError);
     57     return;
     58   }
     59 
     60   if (type == REWEIGHT_TO_INITIAL && !(Weight::Properties() & kLeftSemiring)) {
     61     FSTERROR() << "Reweight: Reweighting to the initial state requires "
     62                << "Weight to be left distributive: "
     63                << Weight::Type();
     64     fst->SetProperties(kError, kError);
     65     return;
     66   }
     67 
     68   StateIterator< MutableFst<Arc> > sit(*fst);
     69   for (; !sit.Done(); sit.Next()) {
     70     typename Arc::StateId state = sit.Value();
     71     if (state == potential.size())
     72       break;
     73     typename Arc::Weight weight = potential[state];
     74     if (weight != Weight::Zero()) {
     75       for (MutableArcIterator< MutableFst<Arc> > ait(fst, state);
     76            !ait.Done();
     77            ait.Next()) {
     78         Arc arc = ait.Value();
     79         if (arc.nextstate >= potential.size())
     80           continue;
     81         typename Arc::Weight nextweight = potential[arc.nextstate];
     82         if (nextweight == Weight::Zero())
     83           continue;
     84         if (type == REWEIGHT_TO_INITIAL)
     85           arc.weight = Divide(Times(arc.weight, nextweight), weight,
     86                               DIVIDE_LEFT);
     87         if (type == REWEIGHT_TO_FINAL)
     88           arc.weight = Divide(Times(weight, arc.weight), nextweight,
     89                               DIVIDE_RIGHT);
     90         ait.SetValue(arc);
     91       }
     92       if (type == REWEIGHT_TO_INITIAL)
     93         fst->SetFinal(state, Divide(fst->Final(state), weight, DIVIDE_LEFT));
     94     }
     95     if (type == REWEIGHT_TO_FINAL)
     96       fst->SetFinal(state, Times(weight, fst->Final(state)));
     97   }
     98 
     99   // This handles elements past the end of the potentials array.
    100   for (; !sit.Done(); sit.Next()) {
    101     typename Arc::StateId state = sit.Value();
    102     if (type == REWEIGHT_TO_FINAL)
    103       fst->SetFinal(state, Times(Weight::Zero(), fst->Final(state)));
    104   }
    105 
    106   typename Arc::Weight startweight = fst->Start() < potential.size() ?
    107       potential[fst->Start()] : Weight::Zero();
    108   if ((startweight != Weight::One()) && (startweight != Weight::Zero())) {
    109     if (fst->Properties(kInitialAcyclic, true) & kInitialAcyclic) {
    110       typename Arc::StateId state = fst->Start();
    111       for (MutableArcIterator< MutableFst<Arc> > ait(fst, state);
    112            !ait.Done();
    113            ait.Next()) {
    114         Arc arc = ait.Value();
    115         if (type == REWEIGHT_TO_INITIAL)
    116           arc.weight = Times(startweight, arc.weight);
    117         else
    118           arc.weight = Times(
    119               Divide(Weight::One(), startweight, DIVIDE_RIGHT),
    120               arc.weight);
    121         ait.SetValue(arc);
    122       }
    123       if (type == REWEIGHT_TO_INITIAL)
    124         fst->SetFinal(state, Times(startweight, fst->Final(state)));
    125       else
    126         fst->SetFinal(state, Times(Divide(Weight::One(), startweight,
    127                                           DIVIDE_RIGHT),
    128                                    fst->Final(state)));
    129     } else {
    130       typename Arc::StateId state = fst->AddState();
    131       Weight w = type == REWEIGHT_TO_INITIAL ?  startweight :
    132                  Divide(Weight::One(), startweight, DIVIDE_RIGHT);
    133       Arc arc(0, 0, w, fst->Start());
    134       fst->AddArc(state, arc);
    135       fst->SetStart(state);
    136     }
    137   }
    138 
    139   fst->SetProperties(ReweightProperties(
    140                          fst->Properties(kFstProperties, false)),
    141                      kFstProperties);
    142 }
    143 
    144 }  // namespace fst
    145 
    146 #endif  // FST_LIB_REWEIGHT_H_
    147