1 // concat.h 2 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Copyright 2005-2010 Google, Inc. 16 // Author: riley (at) google.com (Michael Riley) 17 // 18 // \file 19 // Functions and classes to compute the concat of two FSTs. 20 21 #ifndef FST_LIB_CONCAT_H__ 22 #define FST_LIB_CONCAT_H__ 23 24 #include <vector> 25 using std::vector; 26 #include <algorithm> 27 28 #include <fst/mutable-fst.h> 29 #include <fst/rational.h> 30 31 32 namespace fst { 33 34 // Computes the concatenation (product) of two FSTs. If FST1 35 // transduces string x to y with weight a and FST2 transduces string w 36 // to v with weight b, then their concatenation transduces string xw 37 // to yv with Times(a, b). 38 // 39 // This version modifies its MutableFst argument (in first position). 40 // 41 // Complexity: 42 // - Time: O(V1 + V2 + E2) 43 // - Space: O(V1 + V2 + E2) 44 // where Vi = # of states and Ei = # of arcs of the ith FST. 45 // 46 template<class Arc> 47 void Concat(MutableFst<Arc> *fst1, const Fst<Arc> &fst2) { 48 typedef typename Arc::StateId StateId; 49 typedef typename Arc::Label Label; 50 typedef typename Arc::Weight Weight; 51 52 // TODO(riley): restore when voice actions issues fixed 53 // Check that the symbol table are compatible 54 if (!CompatSymbols(fst1->InputSymbols(), fst2.InputSymbols()) || 55 !CompatSymbols(fst1->OutputSymbols(), fst2.OutputSymbols())) { 56 LOG(ERROR) << "Concat: input/output symbol tables of 1st argument " 57 << "do not match input/output symbol tables of 2nd argument"; 58 // fst1->SetProperties(kError, kError); 59 // return; 60 } 61 62 uint64 props1 = fst1->Properties(kFstProperties, false); 63 uint64 props2 = fst2.Properties(kFstProperties, false); 64 65 StateId start1 = fst1->Start(); 66 if (start1 == kNoStateId) { 67 if (props2 & kError) fst1->SetProperties(kError, kError); 68 return; 69 } 70 71 StateId numstates1 = fst1->NumStates(); 72 if (fst2.Properties(kExpanded, false)) 73 fst1->ReserveStates(numstates1 + CountStates(fst2)); 74 75 for (StateIterator< Fst<Arc> > siter2(fst2); 76 !siter2.Done(); 77 siter2.Next()) { 78 StateId s1 = fst1->AddState(); 79 StateId s2 = siter2.Value(); 80 fst1->SetFinal(s1, fst2.Final(s2)); 81 fst1->ReserveArcs(s1, fst2.NumArcs(s2)); 82 for (ArcIterator< Fst<Arc> > aiter(fst2, s2); 83 !aiter.Done(); 84 aiter.Next()) { 85 Arc arc = aiter.Value(); 86 arc.nextstate += numstates1; 87 fst1->AddArc(s1, arc); 88 } 89 } 90 91 StateId start2 = fst2.Start(); 92 for (StateId s1 = 0; s1 < numstates1; ++s1) { 93 Weight final = fst1->Final(s1); 94 if (final != Weight::Zero()) { 95 fst1->SetFinal(s1, Weight::Zero()); 96 if (start2 != kNoStateId) 97 fst1->AddArc(s1, Arc(0, 0, final, start2 + numstates1)); 98 } 99 } 100 if (start2 != kNoStateId) 101 fst1->SetProperties(ConcatProperties(props1, props2), kFstProperties); 102 } 103 104 // Computes the concatentation of two FSTs. This version modifies its 105 // MutableFst argument (in second position). 106 // 107 // Complexity: 108 // - Time: O(V1 + E1) 109 // - Space: O(V1 + E1) 110 // where Vi = # of states and Ei = # of arcs of the ith FST. 111 // 112 template<class Arc> 113 void Concat(const Fst<Arc> &fst1, MutableFst<Arc> *fst2) { 114 typedef typename Arc::StateId StateId; 115 typedef typename Arc::Label Label; 116 typedef typename Arc::Weight Weight; 117 118 // Check that the symbol table are compatible 119 if (!CompatSymbols(fst1.InputSymbols(), fst2->InputSymbols()) || 120 !CompatSymbols(fst1.OutputSymbols(), fst2->OutputSymbols())) { 121 LOG(ERROR) << "Concat: input/output symbol tables of 1st argument " 122 << "do not match input/output symbol tables of 2nd argument"; 123 // fst2->SetProperties(kError, kError); 124 // return; 125 } 126 127 uint64 props1 = fst1.Properties(kFstProperties, false); 128 uint64 props2 = fst2->Properties(kFstProperties, false); 129 130 StateId start2 = fst2->Start(); 131 if (start2 == kNoStateId) { 132 if (props1 & kError) fst2->SetProperties(kError, kError); 133 return; 134 } 135 136 StateId numstates2 = fst2->NumStates(); 137 if (fst1.Properties(kExpanded, false)) 138 fst2->ReserveStates(numstates2 + CountStates(fst1)); 139 140 for (StateIterator< Fst<Arc> > siter(fst1); 141 !siter.Done(); 142 siter.Next()) { 143 StateId s1 = siter.Value(); 144 StateId s2 = fst2->AddState(); 145 Weight final = fst1.Final(s1); 146 fst2->ReserveArcs(s2, fst1.NumArcs(s1) + (final != Weight::Zero() ? 1 : 0)); 147 if (final != Weight::Zero()) 148 fst2->AddArc(s2, Arc(0, 0, final, start2)); 149 for (ArcIterator< Fst<Arc> > aiter(fst1, s1); 150 !aiter.Done(); 151 aiter.Next()) { 152 Arc arc = aiter.Value(); 153 arc.nextstate += numstates2; 154 fst2->AddArc(s2, arc); 155 } 156 } 157 StateId start1 = fst1.Start(); 158 fst2->SetStart(start1 == kNoStateId ? fst2->AddState() : start1 + numstates2); 159 if (start1 != kNoStateId) 160 fst2->SetProperties(ConcatProperties(props1, props2), kFstProperties); 161 } 162 163 164 // Computes the concatentation of two FSTs. This version modifies its 165 // RationalFst input (in first position). 166 template<class Arc> 167 void Concat(RationalFst<Arc> *fst1, const Fst<Arc> &fst2) { 168 fst1->GetImpl()->AddConcat(fst2, true); 169 } 170 171 // Computes the concatentation of two FSTs. This version modifies its 172 // RationalFst input (in second position). 173 template<class Arc> 174 void Concat(const Fst<Arc> &fst1, RationalFst<Arc> *fst2) { 175 fst2->GetImpl()->AddConcat(fst1, false); 176 } 177 178 typedef RationalFstOptions ConcatFstOptions; 179 180 181 // Computes the concatenation (product) of two FSTs; this version is a 182 // delayed Fst. If FST1 transduces string x to y with weight a and FST2 183 // transduces string w to v with weight b, then their concatenation 184 // transduces string xw to yv with Times(a, b). 185 // 186 // Complexity: 187 // - Time: O(v1 + e1 + v2 + e2), 188 // - Space: O(v1 + v2) 189 // where vi = # of states visited and ei = # of arcs visited of the 190 // ith FST. Constant time and space to visit an input state or arc is 191 // assumed and exclusive of caching. 192 template <class A> 193 class ConcatFst : public RationalFst<A> { 194 public: 195 using ImplToFst< RationalFstImpl<A> >::GetImpl; 196 197 typedef A Arc; 198 typedef typename A::Weight Weight; 199 typedef typename A::StateId StateId; 200 201 ConcatFst(const Fst<A> &fst1, const Fst<A> &fst2) { 202 GetImpl()->InitConcat(fst1, fst2); 203 } 204 205 ConcatFst(const Fst<A> &fst1, const Fst<A> &fst2, 206 const ConcatFstOptions &opts) : RationalFst<A>(opts) { 207 GetImpl()->InitConcat(fst1, fst2); 208 } 209 210 // See Fst<>::Copy() for doc. 211 ConcatFst(const ConcatFst<A> &fst, bool safe = false) 212 : RationalFst<A>(fst, safe) {} 213 214 // Get a copy of this ConcatFst. See Fst<>::Copy() for further doc. 215 virtual ConcatFst<A> *Copy(bool safe = false) const { 216 return new ConcatFst<A>(*this, safe); 217 } 218 }; 219 220 221 // Specialization for ConcatFst. 222 template <class A> 223 class StateIterator< ConcatFst<A> > : public StateIterator< RationalFst<A> > { 224 public: 225 explicit StateIterator(const ConcatFst<A> &fst) 226 : StateIterator< RationalFst<A> >(fst) {} 227 }; 228 229 230 // Specialization for ConcatFst. 231 template <class A> 232 class ArcIterator< ConcatFst<A> > : public ArcIterator< RationalFst<A> > { 233 public: 234 typedef typename A::StateId StateId; 235 236 ArcIterator(const ConcatFst<A> &fst, StateId s) 237 : ArcIterator< RationalFst<A> >(fst, s) {} 238 }; 239 240 241 // Useful alias when using StdArc. 242 typedef ConcatFst<StdArc> StdConcatFst; 243 244 } // namespace fst 245 246 #endif // FST_LIB_CONCAT_H__ 247