1 // replace.h 2 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Copyright 2005-2010 Google, Inc. 16 // Author: riley (at) google.com (Michael Riley) 17 // 18 // \file 19 // Recursively replace Fst arcs with other Fst(s) returning a PDT. 20 21 #ifndef FST_EXTENSIONS_PDT_REPLACE_H__ 22 #define FST_EXTENSIONS_PDT_REPLACE_H__ 23 24 #include <tr1/unordered_map> 25 using std::tr1::unordered_map; 26 using std::tr1::unordered_multimap; 27 28 #include <fst/replace.h> 29 30 namespace fst { 31 32 // Hash to paren IDs 33 template <typename S> 34 struct ReplaceParenHash { 35 size_t operator()(const pair<size_t, S> &p) const { 36 return p.first + p.second * kPrime; 37 } 38 private: 39 static const size_t kPrime = 7853; 40 }; 41 42 template <typename S> const size_t ReplaceParenHash<S>::kPrime; 43 44 // Builds a pushdown transducer (PDT) from an RTN specification 45 // identical to that in fst/lib/replace.h. The result is a PDT 46 // encoded as the FST 'ofst' where some transitions are labeled with 47 // open or close parentheses. To be interpreted as a PDT, the parens 48 // must balance on a path (see PdtExpand()). The open/close 49 // parenthesis label pairs are returned in 'parens'. 50 template <class Arc> 51 void Replace(const vector<pair<typename Arc::Label, 52 const Fst<Arc>* > >& ifst_array, 53 MutableFst<Arc> *ofst, 54 vector<pair<typename Arc::Label, 55 typename Arc::Label> > *parens, 56 typename Arc::Label root) { 57 typedef typename Arc::Label Label; 58 typedef typename Arc::StateId StateId; 59 typedef typename Arc::Weight Weight; 60 61 ofst->DeleteStates(); 62 parens->clear(); 63 64 unordered_map<Label, size_t> label2id; 65 for (size_t i = 0; i < ifst_array.size(); ++i) 66 label2id[ifst_array[i].first] = i; 67 68 Label max_label = kNoLabel; 69 size_t max_non_term_count = 0; 70 71 // Queue of non-terminals to replace 72 deque<size_t> non_term_queue; 73 // Map of non-terminals to replace to count 74 unordered_map<Label, size_t> non_term_map; 75 non_term_queue.push_back(root); 76 non_term_map[root] = 1;; 77 78 // PDT state corr. to ith replace FST start state. 79 vector<StateId> fst_start(ifst_array.size(), kNoLabel); 80 // PDT state, weight pairs corr. to ith replace FST final state & weights. 81 vector< vector<pair<StateId, Weight> > > fst_final(ifst_array.size()); 82 83 // Builds single Fst combining all referenced input Fsts. Leaves in the 84 // non-termnals for now. Tabulate the PDT states that correspond to 85 // the start and final states of the input Fsts. 86 for (StateId soff = 0; !non_term_queue.empty(); soff = ofst->NumStates()) { 87 Label label = non_term_queue.front(); 88 non_term_queue.pop_front(); 89 size_t fst_id = label2id[label]; 90 91 const Fst<Arc> *ifst = ifst_array[fst_id].second; 92 for (StateIterator< Fst<Arc> > siter(*ifst); 93 !siter.Done(); siter.Next()) { 94 StateId is = siter.Value(); 95 StateId os = ofst->AddState(); 96 if (is == ifst->Start()) { 97 fst_start[fst_id] = os; 98 if (label == root) 99 ofst->SetStart(os); 100 } 101 if (ifst->Final(is) != Weight::Zero()) { 102 if (label == root) 103 ofst->SetFinal(os, ifst->Final(is)); 104 fst_final[fst_id].push_back(make_pair(os, ifst->Final(is))); 105 } 106 for (ArcIterator< Fst<Arc> > aiter(*ifst, is); 107 !aiter.Done(); aiter.Next()) { 108 Arc arc = aiter.Value(); 109 if (max_label == kNoLabel || arc.olabel > max_label) 110 max_label = arc.olabel; 111 typename unordered_map<Label, size_t>::const_iterator it = 112 label2id.find(arc.olabel); 113 if (it != label2id.end()) { 114 size_t nfst_id = it->second; 115 if (ifst_array[nfst_id].second->Start() == -1) 116 continue; 117 size_t count = non_term_map[arc.olabel]++; 118 if (count == 0) 119 non_term_queue.push_back(arc.olabel); 120 if (count > max_non_term_count) 121 max_non_term_count = count; 122 } 123 arc.nextstate += soff; 124 ofst->AddArc(os, arc); 125 } 126 } 127 } 128 129 // Changes each non-terminal transition to an open parenthesis 130 // transition redirected to the PDT state that corresponds to the 131 // start state of the input FST for the non-terminal. Adds close parenthesis 132 // transitions from the PDT states corr. to the final states of the 133 // input FST for the non-terminal to the former destination state of the 134 // non-terminal transition. 135 136 typedef MutableArcIterator< MutableFst<Arc> > MIter; 137 typedef unordered_map<pair<size_t, StateId >, size_t, 138 ReplaceParenHash<StateId> > ParenMap; 139 140 // Parenthesis pair ID per fst, state pair. 141 ParenMap paren_map; 142 // # of parenthesis pairs per fst. 143 vector<size_t> nparens(ifst_array.size(), 0); 144 // Initial open parenthesis label 145 Label first_open_paren = max_label + 1; 146 Label first_close_paren = max_label + max_non_term_count + 1; 147 148 for (StateIterator< Fst<Arc> > siter(*ofst); 149 !siter.Done(); siter.Next()) { 150 StateId os = siter.Value(); 151 MIter *aiter = new MIter(ofst, os); 152 for (size_t n = 0; !aiter->Done(); aiter->Next(), ++n) { 153 Arc arc = aiter->Value(); 154 typename unordered_map<Label, size_t>::const_iterator lit = 155 label2id.find(arc.olabel); 156 if (lit != label2id.end()) { 157 size_t nfst_id = lit->second; 158 159 // Get parentheses. Ensures distinct parenthesis pair per 160 // non-terminal and destination state but otherwise reuses them. 161 Label open_paren = kNoLabel, close_paren = kNoLabel; 162 pair<size_t, StateId> paren_key(nfst_id, arc.nextstate); 163 typename ParenMap::const_iterator pit = paren_map.find(paren_key); 164 if (pit != paren_map.end()) { 165 size_t paren_id = pit->second; 166 open_paren = (*parens)[paren_id].first; 167 close_paren = (*parens)[paren_id].second; 168 } else { 169 size_t paren_id = nparens[nfst_id]++; 170 open_paren = first_open_paren + paren_id; 171 close_paren = first_close_paren + paren_id; 172 paren_map[paren_key] = paren_id; 173 if (paren_id >= parens->size()) 174 parens->push_back(make_pair(open_paren, close_paren)); 175 } 176 177 // Sets open parenthesis. 178 Arc sarc(open_paren, open_paren, arc.weight, fst_start[nfst_id]); 179 aiter->SetValue(sarc); 180 181 // Adds close parentheses. 182 for (size_t i = 0; i < fst_final[nfst_id].size(); ++i) { 183 pair<StateId, Weight> &p = fst_final[nfst_id][i]; 184 Arc farc(close_paren, close_paren, p.second, arc.nextstate); 185 186 ofst->AddArc(p.first, farc); 187 if (os == p.first) { // Invalidated iterator 188 delete aiter; 189 aiter = new MIter(ofst, os); 190 aiter->Seek(n); 191 } 192 } 193 } 194 } 195 delete aiter; 196 } 197 } 198 199 } // namespace fst 200 201 #endif // FST_EXTENSIONS_PDT_REPLACE_H__ 202