// compose.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley (at) google.com (Michael Riley)
//
// \file
// Compose a PDT and an FST.

#ifndef FST_EXTENSIONS_PDT_COMPOSE_H__
#define FST_EXTENSIONS_PDT_COMPOSE_H__

#include <list>

#include <fst/extensions/pdt/pdt.h>
#include <fst/compose.h>

namespace fst {

// ParenMatcher flag: Find(kNoLabel) returns the arcs labeled with a
// registered open or close paren.
const uint32 kParenList = 0x00000001;

// ParenMatcher flag: Find(paren) returns an implicit loop arc whose matched
// side is labeled kNoLabel.
const uint32 kParenLoop = 0x00000002;

// This class is a matcher that treats parens as multi-epsilon labels.
// It is most efficient if the parens are in a range non-overlapping with
// the non-paren labels.
40 template <class F> 41 class ParenMatcher { 42 public: 43 typedef SortedMatcher<F> M; 44 typedef typename M::FST FST; 45 typedef typename M::Arc Arc; 46 typedef typename Arc::StateId StateId; 47 typedef typename Arc::Label Label; 48 typedef typename Arc::Weight Weight; 49 50 ParenMatcher(const FST &fst, MatchType match_type, 51 uint32 flags = (kParenLoop | kParenList)) 52 : matcher_(fst, match_type), 53 match_type_(match_type), 54 flags_(flags) { 55 if (match_type == MATCH_INPUT) { 56 loop_.ilabel = kNoLabel; 57 loop_.olabel = 0; 58 } else { 59 loop_.ilabel = 0; 60 loop_.olabel = kNoLabel; 61 } 62 loop_.weight = Weight::One(); 63 loop_.nextstate = kNoStateId; 64 } 65 66 ParenMatcher(const ParenMatcher<F> &matcher, bool safe = false) 67 : matcher_(matcher.matcher_, safe), 68 match_type_(matcher.match_type_), 69 flags_(matcher.flags_), 70 open_parens_(matcher.open_parens_), 71 close_parens_(matcher.close_parens_), 72 loop_(matcher.loop_) { 73 loop_.nextstate = kNoStateId; 74 } 75 76 ParenMatcher<F> *Copy(bool safe = false) const { 77 return new ParenMatcher<F>(*this, safe); 78 } 79 80 MatchType Type(bool test) const { return matcher_.Type(test); } 81 82 void SetState(StateId s) { 83 matcher_.SetState(s); 84 loop_.nextstate = s; 85 } 86 87 bool Find(Label match_label); 88 89 bool Done() const { 90 return done_; 91 } 92 93 const Arc& Value() const { 94 return paren_loop_ ? 
loop_ : matcher_.Value(); 95 } 96 97 void Next(); 98 99 const FST &GetFst() const { return matcher_.GetFst(); } 100 101 uint64 Properties(uint64 props) const { return matcher_.Properties(props); } 102 103 uint32 Flags() const { return matcher_.Flags(); } 104 105 void AddOpenParen(Label label) { 106 if (label == 0) { 107 FSTERROR() << "ParenMatcher: Bad open paren label: 0"; 108 } else { 109 open_parens_.Insert(label); 110 } 111 } 112 113 void AddCloseParen(Label label) { 114 if (label == 0) { 115 FSTERROR() << "ParenMatcher: Bad close paren label: 0"; 116 } else { 117 close_parens_.Insert(label); 118 } 119 } 120 121 void RemoveOpenParen(Label label) { 122 if (label == 0) { 123 FSTERROR() << "ParenMatcher: Bad open paren label: 0"; 124 } else { 125 open_parens_.Erase(label); 126 } 127 } 128 129 void RemoveCloseParen(Label label) { 130 if (label == 0) { 131 FSTERROR() << "ParenMatcher: Bad close paren label: 0"; 132 } else { 133 close_parens_.Erase(label); 134 } 135 } 136 137 void ClearOpenParens() { 138 open_parens_.Clear(); 139 } 140 141 void ClearCloseParens() { 142 close_parens_.Clear(); 143 } 144 145 bool IsOpenParen(Label label) const { 146 return open_parens_.Member(label); 147 } 148 149 bool IsCloseParen(Label label) const { 150 return close_parens_.Member(label); 151 } 152 153 private: 154 // Advances matcher to next open paren if it exists, returning true. 155 // O.w. returns false. 156 bool NextOpenParen(); 157 158 // Advances matcher to next open paren if it exists, returning true. 159 // O.w. returns false. 
160 bool NextCloseParen(); 161 162 M matcher_; 163 MatchType match_type_; // Type of match to perform 164 uint32 flags_; 165 166 // open paren label set 167 CompactSet<Label, kNoLabel> open_parens_; 168 169 // close paren label set 170 CompactSet<Label, kNoLabel> close_parens_; 171 172 173 bool open_paren_list_; // Matching open paren list 174 bool close_paren_list_; // Matching close paren list 175 bool paren_loop_; // Current arc is the implicit paren loop 176 mutable Arc loop_; // For non-consuming symbols 177 bool done_; // Matching done 178 179 void operator=(const ParenMatcher<F> &); // Disallow 180 }; 181 182 template <class M> inline 183 bool ParenMatcher<M>::Find(Label match_label) { 184 open_paren_list_ = false; 185 close_paren_list_ = false; 186 paren_loop_ = false; 187 done_ = false; 188 189 // Returns all parenthesis arcs 190 if (match_label == kNoLabel && (flags_ & kParenList)) { 191 if (open_parens_.LowerBound() != kNoLabel) { 192 matcher_.LowerBound(open_parens_.LowerBound()); 193 open_paren_list_ = NextOpenParen(); 194 if (open_paren_list_) return true; 195 } 196 if (close_parens_.LowerBound() != kNoLabel) { 197 matcher_.LowerBound(close_parens_.LowerBound()); 198 close_paren_list_ = NextCloseParen(); 199 if (close_paren_list_) return true; 200 } 201 } 202 203 // Returns 'implicit' paren loop 204 if (match_label > 0 && (flags_ & kParenLoop) && 205 (IsOpenParen(match_label) || IsCloseParen(match_label))) { 206 paren_loop_ = true; 207 return true; 208 } 209 210 // Returns all other labels 211 if (matcher_.Find(match_label)) 212 return true; 213 214 done_ = true; 215 return false; 216 } 217 218 template <class F> inline 219 void ParenMatcher<F>::Next() { 220 if (paren_loop_) { 221 paren_loop_ = false; 222 done_ = true; 223 } else if (open_paren_list_) { 224 matcher_.Next(); 225 open_paren_list_ = NextOpenParen(); 226 if (open_paren_list_) return; 227 228 if (close_parens_.LowerBound() != kNoLabel) { 229 matcher_.LowerBound(close_parens_.LowerBound()); 
230 close_paren_list_ = NextCloseParen(); 231 if (close_paren_list_) return; 232 } 233 done_ = !matcher_.Find(kNoLabel); 234 } else if (close_paren_list_) { 235 matcher_.Next(); 236 close_paren_list_ = NextCloseParen(); 237 if (close_paren_list_) return; 238 done_ = !matcher_.Find(kNoLabel); 239 } else { 240 matcher_.Next(); 241 done_ = matcher_.Done(); 242 } 243 } 244 245 // Advances matcher to next open paren if it exists, returning true. 246 // O.w. returns false. 247 template <class F> inline 248 bool ParenMatcher<F>::NextOpenParen() { 249 for (; !matcher_.Done(); matcher_.Next()) { 250 Label label = match_type_ == MATCH_INPUT ? 251 matcher_.Value().ilabel : matcher_.Value().olabel; 252 if (label > open_parens_.UpperBound()) 253 return false; 254 if (IsOpenParen(label)) 255 return true; 256 } 257 return false; 258 } 259 260 // Advances matcher to next close paren if it exists, returning true. 261 // O.w. returns false. 262 template <class F> inline 263 bool ParenMatcher<F>::NextCloseParen() { 264 for (; !matcher_.Done(); matcher_.Next()) { 265 Label label = match_type_ == MATCH_INPUT ? 
266 matcher_.Value().ilabel : matcher_.Value().olabel; 267 if (label > close_parens_.UpperBound()) 268 return false; 269 if (IsCloseParen(label)) 270 return true; 271 } 272 return false; 273 } 274 275 276 template <class F> 277 class ParenFilter { 278 public: 279 typedef typename F::FST1 FST1; 280 typedef typename F::FST2 FST2; 281 typedef typename F::Arc Arc; 282 typedef typename Arc::StateId StateId; 283 typedef typename Arc::Label Label; 284 typedef typename Arc::Weight Weight; 285 typedef typename F::Matcher1 Matcher1; 286 typedef typename F::Matcher2 Matcher2; 287 typedef typename F::FilterState FilterState1; 288 typedef StateId StackId; 289 typedef PdtStack<StackId, Label> ParenStack; 290 typedef IntegerFilterState<StackId> FilterState2; 291 typedef PairFilterState<FilterState1, FilterState2> FilterState; 292 typedef ParenFilter<F> Filter; 293 294 ParenFilter(const FST1 &fst1, const FST2 &fst2, 295 Matcher1 *matcher1 = 0, Matcher2 *matcher2 = 0, 296 const vector<pair<Label, Label> > *parens = 0, 297 bool expand = false, bool keep_parens = true) 298 : filter_(fst1, fst2, matcher1, matcher2), 299 parens_(parens ? 
*parens : vector<pair<Label, Label> >()), 300 expand_(expand), 301 keep_parens_(keep_parens), 302 f_(FilterState::NoState()), 303 stack_(parens_), 304 paren_id_(-1) { 305 if (parens) { 306 for (size_t i = 0; i < parens->size(); ++i) { 307 const pair<Label, Label> &p = (*parens)[i]; 308 parens_.push_back(p); 309 GetMatcher1()->AddOpenParen(p.first); 310 GetMatcher2()->AddOpenParen(p.first); 311 if (!expand_) { 312 GetMatcher1()->AddCloseParen(p.second); 313 GetMatcher2()->AddCloseParen(p.second); 314 } 315 } 316 } 317 } 318 319 ParenFilter(const Filter &filter, bool safe = false) 320 : filter_(filter.filter_, safe), 321 parens_(filter.parens_), 322 expand_(filter.expand_), 323 keep_parens_(filter.keep_parens_), 324 f_(FilterState::NoState()), 325 stack_(filter.parens_), 326 paren_id_(-1) { } 327 328 FilterState Start() const { 329 return FilterState(filter_.Start(), FilterState2(0)); 330 } 331 332 void SetState(StateId s1, StateId s2, const FilterState &f) { 333 f_ = f; 334 filter_.SetState(s1, s2, f_.GetState1()); 335 if (!expand_) 336 return; 337 338 ssize_t paren_id = stack_.Top(f.GetState2().GetState()); 339 if (paren_id != paren_id_) { 340 if (paren_id_ != -1) { 341 GetMatcher1()->RemoveCloseParen(parens_[paren_id_].second); 342 GetMatcher2()->RemoveCloseParen(parens_[paren_id_].second); 343 } 344 paren_id_ = paren_id; 345 if (paren_id_ != -1) { 346 GetMatcher1()->AddCloseParen(parens_[paren_id_].second); 347 GetMatcher2()->AddCloseParen(parens_[paren_id_].second); 348 } 349 } 350 } 351 352 FilterState FilterArc(Arc *arc1, Arc *arc2) const { 353 FilterState1 f1 = filter_.FilterArc(arc1, arc2); 354 const FilterState2 &f2 = f_.GetState2(); 355 if (f1 == FilterState1::NoState()) 356 return FilterState::NoState(); 357 358 if (arc1->olabel == kNoLabel && arc2->ilabel) { // arc2 parentheses 359 if (keep_parens_) { 360 arc1->ilabel = arc2->ilabel; 361 } else if (arc2->ilabel) { 362 arc2->olabel = arc1->ilabel; 363 } 364 return FilterParen(arc2->ilabel, f1, f2); 365 } 
else if (arc2->ilabel == kNoLabel && arc1->olabel) { // arc1 parentheses 366 if (keep_parens_) { 367 arc2->olabel = arc1->olabel; 368 } else { 369 arc1->ilabel = arc2->olabel; 370 } 371 return FilterParen(arc1->olabel, f1, f2); 372 } else { 373 return FilterState(f1, f2); 374 } 375 } 376 377 void FilterFinal(Weight *w1, Weight *w2) const { 378 if (f_.GetState2().GetState() != 0) 379 *w1 = Weight::Zero(); 380 filter_.FilterFinal(w1, w2); 381 } 382 383 // Return resp matchers. Ownership stays with filter. 384 Matcher1 *GetMatcher1() { return filter_.GetMatcher1(); } 385 Matcher2 *GetMatcher2() { return filter_.GetMatcher2(); } 386 387 uint64 Properties(uint64 iprops) const { 388 uint64 oprops = filter_.Properties(iprops); 389 return oprops & kILabelInvariantProperties & kOLabelInvariantProperties; 390 } 391 392 private: 393 const FilterState FilterParen(Label label, const FilterState1 &f1, 394 const FilterState2 &f2) const { 395 if (!expand_) 396 return FilterState(f1, f2); 397 398 StackId stack_id = stack_.Find(f2.GetState(), label); 399 if (stack_id < 0) { 400 return FilterState::NoState(); 401 } else { 402 return FilterState(f1, FilterState2(stack_id)); 403 } 404 } 405 406 F filter_; 407 vector<pair<Label, Label> > parens_; 408 bool expand_; // Expands to FST 409 bool keep_parens_; // Retains parentheses in output 410 FilterState f_; // Current filter state 411 mutable ParenStack stack_; 412 ssize_t paren_id_; 413 }; 414 415 // Class to setup composition options for PDT composition. 416 // Default is for the PDT as the first composition argument. 
417 template <class Arc, bool left_pdt = true> 418 class PdtComposeFstOptions : public 419 ComposeFstOptions<Arc, 420 ParenMatcher< Fst<Arc> >, 421 ParenFilter<AltSequenceComposeFilter< 422 ParenMatcher< Fst<Arc> > > > > { 423 public: 424 typedef typename Arc::Label Label; 425 typedef ParenMatcher< Fst<Arc> > PdtMatcher; 426 typedef ParenFilter<AltSequenceComposeFilter<PdtMatcher> > PdtFilter; 427 typedef ComposeFstOptions<Arc, PdtMatcher, PdtFilter> COptions; 428 using COptions::matcher1; 429 using COptions::matcher2; 430 using COptions::filter; 431 432 PdtComposeFstOptions(const Fst<Arc> &ifst1, 433 const vector<pair<Label, Label> > &parens, 434 const Fst<Arc> &ifst2, bool expand = false, 435 bool keep_parens = true) { 436 matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenList); 437 matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenLoop); 438 439 filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens, 440 expand, keep_parens); 441 } 442 }; 443 444 // Class to setup composition options for PDT with FST composition. 445 // Specialization is for the FST as the first composition argument. 
446 template <class Arc> 447 class PdtComposeFstOptions<Arc, false> : public 448 ComposeFstOptions<Arc, 449 ParenMatcher< Fst<Arc> >, 450 ParenFilter<SequenceComposeFilter< 451 ParenMatcher< Fst<Arc> > > > > { 452 public: 453 typedef typename Arc::Label Label; 454 typedef ParenMatcher< Fst<Arc> > PdtMatcher; 455 typedef ParenFilter<SequenceComposeFilter<PdtMatcher> > PdtFilter; 456 typedef ComposeFstOptions<Arc, PdtMatcher, PdtFilter> COptions; 457 using COptions::matcher1; 458 using COptions::matcher2; 459 using COptions::filter; 460 461 PdtComposeFstOptions(const Fst<Arc> &ifst1, 462 const Fst<Arc> &ifst2, 463 const vector<pair<Label, Label> > &parens, 464 bool expand = false, bool keep_parens = true) { 465 matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenLoop); 466 matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenList); 467 468 filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens, 469 expand, keep_parens); 470 } 471 }; 472 473 enum PdtComposeFilter { 474 PAREN_FILTER, // Bar-Hillel construction; keeps parentheses 475 EXPAND_FILTER, // Bar-Hillel + expansion; removes parentheses 476 EXPAND_PAREN_FILTER, // Bar-Hillel + expansion; keeps parentheses 477 }; 478 479 struct PdtComposeOptions { 480 bool connect; // Connect output 481 PdtComposeFilter filter_type; // Which pre-defined filter to use 482 483 explicit PdtComposeOptions(bool c, PdtComposeFilter ft = PAREN_FILTER) 484 : connect(c), filter_type(ft) {} 485 PdtComposeOptions() : connect(true), filter_type(PAREN_FILTER) {} 486 }; 487 488 // Composes pushdown transducer (PDT) encoded as an FST (1st arg) and 489 // an FST (2nd arg) with the result also a PDT encoded as an Fst. (3rd arg). 490 // In the PDTs, some transitions are labeled with open or close 491 // parentheses. To be interpreted as a PDT, the parens must balance on 492 // a path (see PdtExpand()). The open-close parenthesis label pairs 493 // are passed in 'parens'. 
494 template <class Arc> 495 void Compose(const Fst<Arc> &ifst1, 496 const vector<pair<typename Arc::Label, 497 typename Arc::Label> > &parens, 498 const Fst<Arc> &ifst2, 499 MutableFst<Arc> *ofst, 500 const PdtComposeOptions &opts = PdtComposeOptions()) { 501 bool expand = opts.filter_type != PAREN_FILTER; 502 bool keep_parens = opts.filter_type != EXPAND_FILTER; 503 PdtComposeFstOptions<Arc, true> copts(ifst1, parens, ifst2, 504 expand, keep_parens); 505 copts.gc_limit = 0; 506 *ofst = ComposeFst<Arc>(ifst1, ifst2, copts); 507 if (opts.connect) 508 Connect(ofst); 509 } 510 511 // Composes an FST (1st arg) and pushdown transducer (PDT) encoded as 512 // an FST (2nd arg) with the result also a PDT encoded as an Fst (3rd arg). 513 // In the PDTs, some transitions are labeled with open or close 514 // parentheses. To be interpreted as a PDT, the parens must balance on 515 // a path (see ExpandFst()). The open-close parenthesis label pairs 516 // are passed in 'parens'. 517 template <class Arc> 518 void Compose(const Fst<Arc> &ifst1, 519 const Fst<Arc> &ifst2, 520 const vector<pair<typename Arc::Label, 521 typename Arc::Label> > &parens, 522 MutableFst<Arc> *ofst, 523 const PdtComposeOptions &opts = PdtComposeOptions()) { 524 bool expand = opts.filter_type != PAREN_FILTER; 525 bool keep_parens = opts.filter_type != EXPAND_FILTER; 526 PdtComposeFstOptions<Arc, false> copts(ifst1, ifst2, parens, 527 expand, keep_parens); 528 copts.gc_limit = 0; 529 *ofst = ComposeFst<Arc>(ifst1, ifst2, copts); 530 if (opts.connect) 531 Connect(ofst); 532 } 533 534 } // namespace fst 535 536 #endif // FST_EXTENSIONS_PDT_COMPOSE_H__ 537