Home | History | Annotate | Download | only in fst
      1 // concat.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // \file
     19 // Functions and classes to compute the concat of two FSTs.
     20 
     21 #ifndef FST_LIB_CONCAT_H__
     22 #define FST_LIB_CONCAT_H__
     23 
     24 #include <vector>
     25 using std::vector;
     26 #include <algorithm>
     27 
     28 #include <fst/mutable-fst.h>
     29 #include <fst/rational.h>
     30 
     31 
     32 namespace fst {
     33 
     34 // Computes the concatenation (product) of two FSTs. If FST1
     35 // transduces string x to y with weight a and FST2 transduces string w
     36 // to v with weight b, then their concatenation transduces string xw
     37 // to yv with Times(a, b).
     38 //
     39 // This version modifies its MutableFst argument (in first position).
     40 //
     41 // Complexity:
     42 // - Time: O(V1 + V2 + E2)
     43 // - Space: O(V1 + V2 + E2)
     44 // where Vi = # of states and Ei = # of arcs of the ith FST.
     45 //
     46 template<class Arc>
     47 void Concat(MutableFst<Arc> *fst1, const Fst<Arc> &fst2) {
     48   typedef typename Arc::StateId StateId;
     49   typedef typename Arc::Label Label;
     50   typedef typename Arc::Weight Weight;
     51 
     52   // TODO(riley): restore when voice actions issues fixed
     53   // Check that the symbol table are compatible
     54   if (!CompatSymbols(fst1->InputSymbols(), fst2.InputSymbols()) ||
     55       !CompatSymbols(fst1->OutputSymbols(), fst2.OutputSymbols())) {
     56     LOG(ERROR) << "Concat: input/output symbol tables of 1st argument "
     57                << "do not match input/output symbol tables of 2nd argument";
     58     // fst1->SetProperties(kError, kError);
     59     // return;
     60   }
     61 
     62   uint64 props1 = fst1->Properties(kFstProperties, false);
     63   uint64 props2 = fst2.Properties(kFstProperties, false);
     64 
     65   StateId start1 = fst1->Start();
     66   if (start1 == kNoStateId) {
     67     if (props2 & kError) fst1->SetProperties(kError, kError);
     68     return;
     69   }
     70 
     71   StateId numstates1 = fst1->NumStates();
     72   if (fst2.Properties(kExpanded, false))
     73     fst1->ReserveStates(numstates1 + CountStates(fst2));
     74 
     75   for (StateIterator< Fst<Arc> > siter2(fst2);
     76        !siter2.Done();
     77        siter2.Next()) {
     78     StateId s1 = fst1->AddState();
     79     StateId s2 = siter2.Value();
     80     fst1->SetFinal(s1, fst2.Final(s2));
     81     fst1->ReserveArcs(s1, fst2.NumArcs(s2));
     82     for (ArcIterator< Fst<Arc> > aiter(fst2, s2);
     83          !aiter.Done();
     84          aiter.Next()) {
     85       Arc arc = aiter.Value();
     86       arc.nextstate += numstates1;
     87       fst1->AddArc(s1, arc);
     88     }
     89   }
     90 
     91   StateId start2 = fst2.Start();
     92   for (StateId s1 = 0; s1 < numstates1; ++s1) {
     93     Weight final = fst1->Final(s1);
     94     if (final != Weight::Zero()) {
     95       fst1->SetFinal(s1, Weight::Zero());
     96       if (start2 != kNoStateId)
     97         fst1->AddArc(s1, Arc(0, 0, final, start2 + numstates1));
     98     }
     99   }
    100   if (start2 != kNoStateId)
    101     fst1->SetProperties(ConcatProperties(props1, props2), kFstProperties);
    102 }
    103 
    104 // Computes the concatentation of two FSTs.  This version modifies its
    105 // MutableFst argument (in second position).
    106 //
    107 // Complexity:
    108 // - Time: O(V1 + E1)
    109 // - Space: O(V1 + E1)
    110 // where Vi = # of states and Ei = # of arcs of the ith FST.
    111 //
    112 template<class Arc>
    113 void Concat(const Fst<Arc> &fst1, MutableFst<Arc> *fst2) {
    114   typedef typename Arc::StateId StateId;
    115   typedef typename Arc::Label Label;
    116   typedef typename Arc::Weight Weight;
    117 
    118   // Check that the symbol table are compatible
    119   if (!CompatSymbols(fst1.InputSymbols(), fst2->InputSymbols()) ||
    120       !CompatSymbols(fst1.OutputSymbols(), fst2->OutputSymbols())) {
    121     LOG(ERROR) << "Concat: input/output symbol tables of 1st argument "
    122                << "do not match input/output symbol tables of 2nd argument";
    123     // fst2->SetProperties(kError, kError);
    124     // return;
    125   }
    126 
    127   uint64 props1 = fst1.Properties(kFstProperties, false);
    128   uint64 props2 = fst2->Properties(kFstProperties, false);
    129 
    130   StateId start2 = fst2->Start();
    131   if (start2 == kNoStateId) {
    132     if (props1 & kError) fst2->SetProperties(kError, kError);
    133     return;
    134   }
    135 
    136   StateId numstates2 = fst2->NumStates();
    137   if (fst1.Properties(kExpanded, false))
    138     fst2->ReserveStates(numstates2 + CountStates(fst1));
    139 
    140   for (StateIterator< Fst<Arc> > siter(fst1);
    141        !siter.Done();
    142        siter.Next()) {
    143     StateId s1 = siter.Value();
    144     StateId s2 = fst2->AddState();
    145     Weight final = fst1.Final(s1);
    146     fst2->ReserveArcs(s2, fst1.NumArcs(s1) + (final != Weight::Zero() ? 1 : 0));
    147     if (final != Weight::Zero())
    148       fst2->AddArc(s2, Arc(0, 0, final, start2));
    149     for (ArcIterator< Fst<Arc> > aiter(fst1, s1);
    150          !aiter.Done();
    151          aiter.Next()) {
    152       Arc arc = aiter.Value();
    153       arc.nextstate += numstates2;
    154       fst2->AddArc(s2, arc);
    155     }
    156   }
    157   StateId start1 = fst1.Start();
    158   fst2->SetStart(start1 == kNoStateId ? fst2->AddState() : start1 + numstates2);
    159   if (start1 != kNoStateId)
    160     fst2->SetProperties(ConcatProperties(props1, props2), kFstProperties);
    161 }
    162 
    163 
    164 // Computes the concatentation of two FSTs. This version modifies its
    165 // RationalFst input (in first position).
    166 template<class Arc>
    167 void Concat(RationalFst<Arc> *fst1, const Fst<Arc> &fst2) {
    168   fst1->GetImpl()->AddConcat(fst2, true);
    169 }
    170 
    171 // Computes the concatentation of two FSTs. This version modifies its
    172 // RationalFst input (in second position).
    173 template<class Arc>
    174 void Concat(const Fst<Arc> &fst1, RationalFst<Arc> *fst2) {
    175   fst2->GetImpl()->AddConcat(fst1, false);
    176 }
    177 
    178 typedef RationalFstOptions ConcatFstOptions;
    179 
    180 
    181 // Computes the concatenation (product) of two FSTs; this version is a
    182 // delayed Fst. If FST1 transduces string x to y with weight a and FST2
    183 // transduces string w to v with weight b, then their concatenation
    184 // transduces string xw to yv with Times(a, b).
    185 //
    186 // Complexity:
    187 // - Time: O(v1 + e1 + v2 + e2),
    188 // - Space: O(v1 + v2)
    189 // where vi = # of states visited and ei = # of arcs visited of the
    190 // ith FST. Constant time and space to visit an input state or arc is
    191 // assumed and exclusive of caching.
    192 template <class A>
    193 class ConcatFst : public RationalFst<A> {
    194  public:
    195   using ImplToFst< RationalFstImpl<A> >::GetImpl;
    196 
    197   typedef A Arc;
    198   typedef typename A::Weight Weight;
    199   typedef typename A::StateId StateId;
    200 
    201   ConcatFst(const Fst<A> &fst1, const Fst<A> &fst2) {
    202     GetImpl()->InitConcat(fst1, fst2);
    203   }
    204 
    205   ConcatFst(const Fst<A> &fst1, const Fst<A> &fst2,
    206             const ConcatFstOptions &opts) : RationalFst<A>(opts) {
    207     GetImpl()->InitConcat(fst1, fst2);
    208   }
    209 
    210   // See Fst<>::Copy() for doc.
    211   ConcatFst(const ConcatFst<A> &fst, bool safe = false)
    212       : RationalFst<A>(fst, safe) {}
    213 
    214   // Get a copy of this ConcatFst. See Fst<>::Copy() for further doc.
    215   virtual ConcatFst<A> *Copy(bool safe = false) const {
    216     return new ConcatFst<A>(*this, safe);
    217   }
    218 };
    219 
    220 
    221 // Specialization for ConcatFst.
    222 template <class A>
    223 class StateIterator< ConcatFst<A> > : public StateIterator< RationalFst<A> > {
    224  public:
    225   explicit StateIterator(const ConcatFst<A> &fst)
    226       : StateIterator< RationalFst<A> >(fst) {}
    227 };
    228 
    229 
    230 // Specialization for ConcatFst.
    231 template <class A>
    232 class ArcIterator< ConcatFst<A> > : public ArcIterator< RationalFst<A> > {
    233  public:
    234   typedef typename A::StateId StateId;
    235 
    236   ArcIterator(const ConcatFst<A> &fst, StateId s)
    237       : ArcIterator< RationalFst<A> >(fst, s) {}
    238 };
    239 
    240 
    241 // Useful alias when using StdArc.
    242 typedef ConcatFst<StdArc> StdConcatFst;
    243 
    244 }  // namespace fst
    245 
    246 #endif  // FST_LIB_CONCAT_H__
    247