1 2 // Licensed under the Apache License, Version 2.0 (the "License"); 3 // you may not use this file except in compliance with the License. 4 // You may obtain a copy of the License at 5 // 6 // http://www.apache.org/licenses/LICENSE-2.0 7 // 8 // Unless required by applicable law or agreed to in writing, software 9 // distributed under the License is distributed on an "AS IS" BASIS, 10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 // 14 // Copyright 2005-2010 Google, Inc. 15 // Author: jpr (at) google.com (Jake Ratkiewicz) 16 // Convenience file for including all PDT operations at once, and/or 17 // registering them for new arc types. 18 19 #ifndef FST_EXTENSIONS_PDT_PDTSCRIPT_H_ 20 #define FST_EXTENSIONS_PDT_PDTSCRIPT_H_ 21 22 #include <utility> 23 using std::pair; using std::make_pair; 24 #include <vector> 25 using std::vector; 26 27 #include <fst/compose.h> // for ComposeOptions 28 #include <fst/util.h> 29 30 #include <fst/script/fst-class.h> 31 #include <fst/script/arg-packs.h> 32 #include <fst/script/shortest-path.h> 33 34 #include <fst/extensions/pdt/compose.h> 35 #include <fst/extensions/pdt/expand.h> 36 #include <fst/extensions/pdt/info.h> 37 #include <fst/extensions/pdt/replace.h> 38 #include <fst/extensions/pdt/reverse.h> 39 #include <fst/extensions/pdt/shortest-path.h> 40 41 42 namespace fst { 43 namespace script { 44 45 // PDT COMPOSE 46 47 typedef args::Package<const FstClass &, 48 const FstClass &, 49 const vector<pair<int64, int64> >&, 50 MutableFstClass *, 51 const PdtComposeOptions &, 52 bool> PdtComposeArgs; 53 54 template<class Arc> 55 void PdtCompose(PdtComposeArgs *args) { 56 const Fst<Arc> &ifst1 = *(args->arg1.GetFst<Arc>()); 57 const Fst<Arc> &ifst2 = *(args->arg2.GetFst<Arc>()); 58 MutableFst<Arc> *ofst = args->arg4->GetMutableFst<Arc>(); 59 60 vector<pair<typename Arc::Label, typename Arc::Label> > parens( 61 args->arg3.size()); 62 63 for (size_t i = 0; i < parens.size(); ++i) { 64 parens[i].first = args->arg3[i].first; 65 parens[i].second = args->arg3[i].second; 66 } 67 68 if (args->arg6) { 69 Compose(ifst1, parens, ifst2, ofst, args->arg5); 70 } else { 71 Compose(ifst1, ifst2, parens, ofst, args->arg5); 72 } 73 } 74 75 void PdtCompose(const FstClass & ifst1, 76 const FstClass & ifst2, 77 const vector<pair<int64, int64> > &parens, 78 MutableFstClass *ofst, 79 const PdtComposeOptions &copts, 80 bool left_pdt); 81 82 // PDT EXPAND 83 84 struct PdtExpandOptions { 85 bool connect; 86 bool keep_parentheses; 87 WeightClass weight_threshold; 88 89 PdtExpandOptions(bool c = true, bool k = false, 90 WeightClass w = WeightClass::Zero()) 91 : connect(c), keep_parentheses(k), weight_threshold(w) {} 92 }; 93 94 typedef args::Package<const FstClass &, 95 const vector<pair<int64, int64> >&, 96 MutableFstClass *, PdtExpandOptions> PdtExpandArgs; 97 98 template<class Arc> 99 void PdtExpand(PdtExpandArgs *args) { 100 const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>()); 101 MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>(); 102 103 vector<pair<typename Arc::Label, typename Arc::Label> > parens( 104 args->arg2.size()); 105 for (size_t i = 0; i < parens.size(); ++i) { 106 parens[i].first = args->arg2[i].first; 107 parens[i].second = args->arg2[i].second; 108 } 109 Expand(fst, parens, ofst, 110 ExpandOptions<Arc>( 111 args->arg4.connect, args->arg4.keep_parentheses, 112 *(args->arg4.weight_threshold.GetWeight<typename Arc::Weight>()))); 113 } 114 115 void PdtExpand(const FstClass &ifst, 116 const vector<pair<int64, int64> > &parens, 117 MutableFstClass *ofst, const PdtExpandOptions &opts); 118 119 void PdtExpand(const FstClass &ifst, 120 const vector<pair<int64, int64> > &parens, 121 MutableFstClass *ofst, bool connect); 122 123 // PDT REPLACE 124 125 typedef args::Package<const vector<pair<int64, const FstClass*> > &, 126 MutableFstClass *, 127 vector<pair<int64, int64> > *, 128 const int64 &> PdtReplaceArgs; 129 template<class Arc> 130 void PdtReplace(PdtReplaceArgs *args) { 131 vector<pair<typename Arc::Label, const Fst<Arc> *> > tuples( 132 args->arg1.size()); 133 for (size_t i = 0; i < tuples.size(); ++i) { 134 tuples[i].first = args->arg1[i].first; 135 tuples[i].second = (args->arg1[i].second)->GetFst<Arc>(); 136 } 137 MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>(); 138 vector<pair<typename Arc::Label, typename Arc::Label> > parens( 139 args->arg3->size()); 140 141 for (size_t i = 0; i < parens.size(); ++i) { 142 parens[i].first = args->arg3->at(i).first; 143 parens[i].second = args->arg3->at(i).second; 144 } 145 Replace(tuples, ofst, &parens, args->arg4); 146 147 // now copy parens back 148 args->arg3->resize(parens.size()); 149 for (size_t i = 0; i < parens.size(); ++i) { 150 (*args->arg3)[i].first = parens[i].first; 151 (*args->arg3)[i].second = parens[i].second; 152 } 153 } 154 155 void PdtReplace(const vector<pair<int64, const FstClass*> > &fst_tuples, 156 MutableFstClass *ofst, 157 vector<pair<int64, int64> > *parens, 158 const int64 &root); 159 160 // PDT REVERSE 161 162 typedef args::Package<const FstClass &, 163 const vector<pair<int64, int64> >&, 164 MutableFstClass *> PdtReverseArgs; 165 166 template<class Arc> 167 void PdtReverse(PdtReverseArgs *args) { 168 const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>()); 169 MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>(); 170 171 vector<pair<typename Arc::Label, typename Arc::Label> > parens( 172 args->arg2.size()); 173 for (size_t i = 0; i < parens.size(); ++i) { 174 parens[i].first = args->arg2[i].first; 175 parens[i].second = args->arg2[i].second; 176 } 177 Reverse(fst, parens, ofst); 178 } 179 180 void PdtReverse(const FstClass &ifst, 181 const vector<pair<int64, int64> > &parens, 182 MutableFstClass *ofst); 183 184 185 // PDT SHORTESTPATH 186 187 struct PdtShortestPathOptions { 188 QueueType queue_type; 189 bool keep_parentheses; 190 bool path_gc; 191 192 PdtShortestPathOptions(QueueType qt = FIFO_QUEUE, 193 bool kp = false, bool gc = true) 194 : queue_type(qt), keep_parentheses(kp), path_gc(gc) {} 195 }; 196 197 typedef args::Package<const FstClass &, 198 const vector<pair<int64, int64> >&, 199 MutableFstClass *, 200 const PdtShortestPathOptions &> PdtShortestPathArgs; 201 202 template<class Arc> 203 void PdtShortestPath(PdtShortestPathArgs *args) { 204 typedef typename Arc::StateId StateId; 205 typedef typename Arc::Label Label; 206 typedef typename Arc::Weight Weight; 207 208 const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>()); 209 MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>(); 210 const PdtShortestPathOptions &opts = args->arg4; 211 212 213 vector<pair<Label, Label> > parens(args->arg2.size()); 214 for (size_t i = 0; i < parens.size(); ++i) { 215 parens[i].first = args->arg2[i].first; 216 parens[i].second = args->arg2[i].second; 217 } 218 219 switch (opts.queue_type) { 220 default: 221 FSTERROR() << "Unknown queue type: " << opts.queue_type; 222 case FIFO_QUEUE: { 223 typedef FifoQueue<StateId> Queue; 224 fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses, 225 opts.path_gc); 226 ShortestPath(fst, parens, ofst, spopts); 227 return; 228 } 229 case LIFO_QUEUE: { 230 typedef LifoQueue<StateId> Queue; 231 fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses, 232 opts.path_gc); 233 ShortestPath(fst, parens, ofst, spopts); 234 return; 235 } 236 case STATE_ORDER_QUEUE: { 237 typedef StateOrderQueue<StateId> Queue; 238 fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses, 239 opts.path_gc); 240 ShortestPath(fst, parens, ofst, spopts); 241 return; 242 } 243 } 244 } 245 246 void PdtShortestPath(const FstClass &ifst, 247 const vector<pair<int64, int64> > &parens, 248 MutableFstClass *ofst, 249 const PdtShortestPathOptions &opts = 250 PdtShortestPathOptions()); 251 252 // PRINT INFO 253 254 typedef args::Package<const FstClass &, 255 const vector<pair<int64, int64> > &> PrintPdtInfoArgs; 256 257 template<class Arc> 258 void PrintPdtInfo(PrintPdtInfoArgs *args) { 259 const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>()); 260 vector<pair<typename Arc::Label, typename Arc::Label> > parens( 261 args->arg2.size()); 262 for (size_t i = 0; i < parens.size(); ++i) { 263 parens[i].first = args->arg2[i].first; 264 parens[i].second = args->arg2[i].second; 265 } 266 PdtInfo<Arc> pdtinfo(fst, parens); 267 PrintPdtInfo(pdtinfo); 268 } 269 270 void PrintPdtInfo(const FstClass &ifst, 271 const vector<pair<int64, int64> > &parens); 272 273 } // namespace script 274 } // namespace fst 275 276 277 #define REGISTER_FST_PDT_OPERATIONS(ArcType) \ 278 REGISTER_FST_OPERATION(PdtCompose, ArcType, PdtComposeArgs); \ 279 REGISTER_FST_OPERATION(PdtExpand, ArcType, PdtExpandArgs); \ 280 REGISTER_FST_OPERATION(PdtReplace, ArcType, PdtReplaceArgs); \ 281 REGISTER_FST_OPERATION(PdtReverse, ArcType, PdtReverseArgs); \ 282 REGISTER_FST_OPERATION(PdtShortestPath, ArcType, PdtShortestPathArgs); \ 283 REGISTER_FST_OPERATION(PrintPdtInfo, ArcType, PrintPdtInfoArgs) 284 #endif // FST_EXTENSIONS_PDT_PDTSCRIPT_H_ 285