1 // lookahead-matcher.h 2 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Copyright 2005-2010 Google, Inc. 16 // Author: riley (at) google.com (Michael Riley) 17 // 18 // \file 19 // Classes to add lookahead to FST matchers, useful e.g. for improving 20 // composition efficiency with certain inputs. 21 22 #ifndef FST_LIB_LOOKAHEAD_MATCHER_H__ 23 #define FST_LIB_LOOKAHEAD_MATCHER_H__ 24 25 #include <fst/add-on.h> 26 #include <fst/const-fst.h> 27 #include <fst/fst.h> 28 #include <fst/label-reachable.h> 29 #include <fst/matcher.h> 30 31 32 DECLARE_string(save_relabel_ipairs); 33 DECLARE_string(save_relabel_opairs); 34 35 namespace fst { 36 37 // LOOKAHEAD MATCHERS - these have the interface of Matchers (see 38 // matcher.h) and these additional methods: 39 // 40 // template <class F> 41 // class LookAheadMatcher { 42 // public: 43 // typedef F FST; 44 // typedef F::Arc Arc; 45 // typedef typename Arc::StateId StateId; 46 // typedef typename Arc::Label Label; 47 // typedef typename Arc::Weight Weight; 48 // 49 // // Required constructors. 50 // LookAheadMatcher(const F &fst, MatchType match_type); 51 // // If safe=true, the copy is thread-safe (except the lookahead Fst is 52 // // preserved). See Fst<>::Cop() for further doc. 53 // LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false); 54 // 55 // Below are methods for looking ahead for a match to a label and 56 // more generally, to a rational set. 
// Each returns false if there is definitely not a match and returns
// true if there possibly is a match.
//
//   // LABEL LOOKAHEAD: Can 'label' be read from the current matcher state
//   // after possibly following epsilon transitions?
//   bool LookAheadLabel(Label label) const;
//
//   // RATIONAL LOOKAHEAD: The next methods allow looking ahead for an
//   // arbitrary rational set of strings, specified by an FST and a state
//   // from which to begin the matching. If the lookahead FST is a
//   // transducer, this looks on the side different from the matcher
//   // 'match_type' (cf. composition).
//
//   // Are there paths P from 's' in the lookahead FST that can be read from
//   // the current matcher state?
//   bool LookAheadFst(const Fst<Arc>& fst, StateId s);
//
//   // Gives an estimate of the combined weight of the paths P in the
//   // lookahead and matcher FSTs for the last call to LookAheadFst.
//   // A trivial implementation returns Weight::One(). Non-trivial
//   // implementations are useful for weight-pushing in composition.
//   Weight LookAheadWeight() const;
//
//   // Is there a single non-epsilon arc found in the lookahead FST
//   // that begins P (after possibly following any epsilons) in the last
//   // call to LookAheadFst? If so, return true and copy it to '*arc',
//   // otherwise return false. A trivial implementation returns false.
//   // Non-trivial implementations are useful for label-pushing in
//   // composition.
//   bool LookAheadPrefix(Arc *arc);
//
//   // Optionally pre-specifies the lookahead FST that will be passed
//   // to LookAheadFst() for possible precomputation. If copy is true,
//   // then 'fst' is a copy of the FST used in the previous call to
//   // this method (useful to avoid unnecessary updates).
//   void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false);
//
// };

//
// LOOK-AHEAD FLAGS (see also kMatcherFlags in matcher.h):
//
// Each flag below occupies a single distinct bit; together they fill
// exactly the bits selected by kLookAheadFlags.

// Matcher is a lookahead matcher when 'match_type' is MATCH_INPUT.
const uint32 kInputLookAheadMatcher = 0x00000010;

// Matcher is a lookahead matcher when 'match_type' is MATCH_OUTPUT.
const uint32 kOutputLookAheadMatcher = 0x00000020;

// Is a non-trivial implementation of the LookAheadWeight() method
// defined, and should it be used?
const uint32 kLookAheadWeight = 0x00000040;

// Is a non-trivial implementation of the LookAheadPrefix() method
// defined, and should it be used?
const uint32 kLookAheadPrefix = 0x00000080;

// Look-ahead of matcher FST non-epsilon arcs?
const uint32 kLookAheadNonEpsilons = 0x00000100;

// Look-ahead of matcher FST epsilon arcs?
const uint32 kLookAheadEpsilons = 0x00000200;

// Ignore epsilon paths for the lookahead prefix? Note this gives
// correct results in composition only with an appropriate composition
// filter since it depends on the filter blocking the ignored paths.
const uint32 kLookAheadNonEpsilonPrefix = 0x00000400;

// For LabelLookAheadMatcher, save relabeling data to file. In this
// header the flag is only forwarded to the LabelReachable constructor.
const uint32 kLookAheadKeepRelabelData = 0x00000800;

// Mask of the flags used for lookahead matchers (the bitwise OR of the
// eight flags above).
const uint32 kLookAheadFlags = 0x00000ff0;

// LookAhead Matcher interface, templated on the Arc definition; used
// for lookahead matcher specializations that are returned by the
// InitMatcher() Fst method.
132 template <class A> 133 class LookAheadMatcherBase : public MatcherBase<A> { 134 public: 135 typedef A Arc; 136 typedef typename A::StateId StateId; 137 typedef typename A::Label Label; 138 typedef typename A::Weight Weight; 139 140 LookAheadMatcherBase() 141 : weight_(Weight::One()), 142 prefix_arc_(kNoLabel, kNoLabel, Weight::One(), kNoStateId) {} 143 144 virtual ~LookAheadMatcherBase() {} 145 146 bool LookAheadLabel(Label label) const { return LookAheadLabel_(label); } 147 148 bool LookAheadFst(const Fst<Arc> &fst, StateId s) { 149 return LookAheadFst_(fst, s); 150 } 151 152 Weight LookAheadWeight() const { return weight_; } 153 154 bool LookAheadPrefix(Arc *arc) const { 155 if (prefix_arc_.nextstate != kNoStateId) { 156 *arc = prefix_arc_; 157 return true; 158 } else { 159 return false; 160 } 161 } 162 163 virtual void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) = 0; 164 165 protected: 166 void SetLookAheadWeight(const Weight &w) { weight_ = w; } 167 168 void SetLookAheadPrefix(const Arc &arc) { prefix_arc_ = arc; } 169 170 void ClearLookAheadPrefix() { prefix_arc_.nextstate = kNoStateId; } 171 172 private: 173 virtual bool LookAheadLabel_(Label label) const = 0; 174 virtual bool LookAheadFst_(const Fst<Arc> &fst, 175 StateId s) = 0; // This must set l.a. weight and 176 // prefix if non-trivial. 177 Weight weight_; // Look-ahead weight 178 Arc prefix_arc_; // Look-ahead prefix arc 179 }; 180 181 182 // Don't really lookahead, just declare future looks good regardless. 
183 template <class M> 184 class TrivialLookAheadMatcher 185 : public LookAheadMatcherBase<typename M::FST::Arc> { 186 public: 187 typedef typename M::FST FST; 188 typedef typename M::Arc Arc; 189 typedef typename Arc::StateId StateId; 190 typedef typename Arc::Label Label; 191 typedef typename Arc::Weight Weight; 192 193 TrivialLookAheadMatcher(const FST &fst, MatchType match_type) 194 : matcher_(fst, match_type) {} 195 196 TrivialLookAheadMatcher(const TrivialLookAheadMatcher<M> &lmatcher, 197 bool safe = false) 198 : matcher_(lmatcher.matcher_, safe) {} 199 200 // General matcher methods 201 TrivialLookAheadMatcher<M> *Copy(bool safe = false) const { 202 return new TrivialLookAheadMatcher<M>(*this, safe); 203 } 204 205 MatchType Type(bool test) const { return matcher_.Type(test); } 206 void SetState(StateId s) { return matcher_.SetState(s); } 207 bool Find(Label label) { return matcher_.Find(label); } 208 bool Done() const { return matcher_.Done(); } 209 const Arc& Value() const { return matcher_.Value(); } 210 void Next() { matcher_.Next(); } 211 virtual const FST &GetFst() const { return matcher_.GetFst(); } 212 uint64 Properties(uint64 props) const { return matcher_.Properties(props); } 213 uint32 Flags() const { 214 return matcher_.Flags() | kInputLookAheadMatcher | kOutputLookAheadMatcher; 215 } 216 217 // Look-ahead methods. 218 bool LookAheadLabel(Label label) const { return true; } 219 bool LookAheadFst(const Fst<Arc> &fst, StateId s) {return true; } 220 Weight LookAheadWeight() const { return Weight::One(); } 221 bool LookAheadPrefix(Arc *arc) const { return false; } 222 void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {} 223 224 private: 225 // This allows base class virtual access to non-virtual derived- 226 // class members of the same name. It makes the derived class more 227 // efficient to use but unsafe to further derive. 
228 virtual void SetState_(StateId s) { SetState(s); } 229 virtual bool Find_(Label label) { return Find(label); } 230 virtual bool Done_() const { return Done(); } 231 virtual const Arc& Value_() const { return Value(); } 232 virtual void Next_() { Next(); } 233 234 bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); } 235 236 bool LookAheadFst_(const Fst<Arc> &fst, StateId s) { 237 return LookAheadFst(fst, s); 238 } 239 240 Weight LookAheadWeight_() const { return LookAheadWeight(); } 241 bool LookAheadPrefix_(Arc *arc) const { return LookAheadPrefix(arc); } 242 243 M matcher_; 244 }; 245 246 // Look-ahead of one transition. Template argument F accepts flags to 247 // control behavior. 248 template <class M, uint32 F = kLookAheadNonEpsilons | kLookAheadEpsilons | 249 kLookAheadWeight | kLookAheadPrefix> 250 class ArcLookAheadMatcher 251 : public LookAheadMatcherBase<typename M::FST::Arc> { 252 public: 253 typedef typename M::FST FST; 254 typedef typename M::Arc Arc; 255 typedef typename Arc::StateId StateId; 256 typedef typename Arc::Label Label; 257 typedef typename Arc::Weight Weight; 258 typedef NullAddOn MatcherData; 259 260 using LookAheadMatcherBase<Arc>::LookAheadWeight; 261 using LookAheadMatcherBase<Arc>::SetLookAheadPrefix; 262 using LookAheadMatcherBase<Arc>::SetLookAheadWeight; 263 using LookAheadMatcherBase<Arc>::ClearLookAheadPrefix; 264 265 ArcLookAheadMatcher(const FST &fst, MatchType match_type, 266 MatcherData *data = 0) 267 : matcher_(fst, match_type), 268 fst_(matcher_.GetFst()), 269 lfst_(0), 270 s_(kNoStateId) {} 271 272 ArcLookAheadMatcher(const ArcLookAheadMatcher<M, F> &lmatcher, 273 bool safe = false) 274 : matcher_(lmatcher.matcher_, safe), 275 fst_(matcher_.GetFst()), 276 lfst_(lmatcher.lfst_), 277 s_(kNoStateId) {} 278 279 // General matcher methods 280 ArcLookAheadMatcher<M, F> *Copy(bool safe = false) const { 281 return new ArcLookAheadMatcher<M, F>(*this, safe); 282 } 283 284 MatchType Type(bool test) const { return 
matcher_.Type(test); } 285 286 void SetState(StateId s) { 287 s_ = s; 288 matcher_.SetState(s); 289 } 290 291 bool Find(Label label) { return matcher_.Find(label); } 292 bool Done() const { return matcher_.Done(); } 293 const Arc& Value() const { return matcher_.Value(); } 294 void Next() { matcher_.Next(); } 295 const FST &GetFst() const { return fst_; } 296 uint64 Properties(uint64 props) const { return matcher_.Properties(props); } 297 uint32 Flags() const { 298 return matcher_.Flags() | kInputLookAheadMatcher | 299 kOutputLookAheadMatcher | F; 300 } 301 302 // Writable matcher methods 303 MatcherData *GetData() const { return 0; } 304 305 // Look-ahead methods. 306 bool LookAheadLabel(Label label) const { return matcher_.Find(label); } 307 308 // Checks if there is a matching (possibly super-final) transition 309 // at (s_, s). 310 bool LookAheadFst(const Fst<Arc> &fst, StateId s); 311 312 void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) { 313 lfst_ = &fst; 314 } 315 316 private: 317 // This allows base class virtual access to non-virtual derived- 318 // class members of the same name. It makes the derived class more 319 // efficient to use but unsafe to further derive. 
320 virtual void SetState_(StateId s) { SetState(s); } 321 virtual bool Find_(Label label) { return Find(label); } 322 virtual bool Done_() const { return Done(); } 323 virtual const Arc& Value_() const { return Value(); } 324 virtual void Next_() { Next(); } 325 326 bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); } 327 bool LookAheadFst_(const Fst<Arc> &fst, StateId s) { 328 return LookAheadFst(fst, s); 329 } 330 331 mutable M matcher_; 332 const FST &fst_; // Matcher FST 333 const Fst<Arc> *lfst_; // Look-ahead FST 334 StateId s_; // Matcher state 335 }; 336 337 template <class M, uint32 F> 338 bool ArcLookAheadMatcher<M, F>::LookAheadFst(const Fst<Arc> &fst, StateId s) { 339 if (&fst != lfst_) 340 InitLookAheadFst(fst); 341 342 bool ret = false; 343 ssize_t nprefix = 0; 344 if (F & kLookAheadWeight) 345 SetLookAheadWeight(Weight::Zero()); 346 if (F & kLookAheadPrefix) 347 ClearLookAheadPrefix(); 348 if (fst_.Final(s_) != Weight::Zero() && 349 lfst_->Final(s) != Weight::Zero()) { 350 if (!(F & (kLookAheadWeight | kLookAheadPrefix))) 351 return true; 352 ++nprefix; 353 if (F & kLookAheadWeight) 354 SetLookAheadWeight(Plus(LookAheadWeight(), 355 Times(fst_.Final(s_), lfst_->Final(s)))); 356 ret = true; 357 } 358 if (matcher_.Find(kNoLabel)) { 359 if (!(F & (kLookAheadWeight | kLookAheadPrefix))) 360 return true; 361 ++nprefix; 362 if (F & kLookAheadWeight) 363 for (; !matcher_.Done(); matcher_.Next()) 364 SetLookAheadWeight(Plus(LookAheadWeight(), matcher_.Value().weight)); 365 ret = true; 366 } 367 for (ArcIterator< Fst<Arc> > aiter(*lfst_, s); 368 !aiter.Done(); 369 aiter.Next()) { 370 const Arc &arc = aiter.Value(); 371 Label label = kNoLabel; 372 switch (matcher_.Type(false)) { 373 case MATCH_INPUT: 374 label = arc.olabel; 375 break; 376 case MATCH_OUTPUT: 377 label = arc.ilabel; 378 break; 379 default: 380 FSTERROR() << "ArcLookAheadMatcher::LookAheadFst: bad match type"; 381 return true; 382 } 383 if (label == 0) { 384 if (!(F & 
(kLookAheadWeight | kLookAheadPrefix))) 385 return true; 386 if (!(F & kLookAheadNonEpsilonPrefix)) 387 ++nprefix; 388 if (F & kLookAheadWeight) 389 SetLookAheadWeight(Plus(LookAheadWeight(), arc.weight)); 390 ret = true; 391 } else if (matcher_.Find(label)) { 392 if (!(F & (kLookAheadWeight | kLookAheadPrefix))) 393 return true; 394 for (; !matcher_.Done(); matcher_.Next()) { 395 ++nprefix; 396 if (F & kLookAheadWeight) 397 SetLookAheadWeight(Plus(LookAheadWeight(), 398 Times(arc.weight, 399 matcher_.Value().weight))); 400 if ((F & kLookAheadPrefix) && nprefix == 1) 401 SetLookAheadPrefix(arc); 402 } 403 ret = true; 404 } 405 } 406 if (F & kLookAheadPrefix) { 407 if (nprefix == 1) 408 SetLookAheadWeight(Weight::One()); // Avoids double counting. 409 else 410 ClearLookAheadPrefix(); 411 } 412 return ret; 413 } 414 415 416 // Template argument F accepts flags to control behavior. 417 // It must include precisely one of KInputLookAheadMatcher or 418 // KOutputLookAheadMatcher. 419 template <class M, uint32 F = kLookAheadEpsilons | kLookAheadWeight | 420 kLookAheadPrefix | kLookAheadNonEpsilonPrefix | 421 kLookAheadKeepRelabelData, 422 class S = DefaultAccumulator<typename M::Arc> > 423 class LabelLookAheadMatcher 424 : public LookAheadMatcherBase<typename M::FST::Arc> { 425 public: 426 typedef typename M::FST FST; 427 typedef typename M::Arc Arc; 428 typedef typename Arc::StateId StateId; 429 typedef typename Arc::Label Label; 430 typedef typename Arc::Weight Weight; 431 typedef LabelReachableData<Label> MatcherData; 432 433 using LookAheadMatcherBase<Arc>::LookAheadWeight; 434 using LookAheadMatcherBase<Arc>::SetLookAheadPrefix; 435 using LookAheadMatcherBase<Arc>::SetLookAheadWeight; 436 using LookAheadMatcherBase<Arc>::ClearLookAheadPrefix; 437 438 LabelLookAheadMatcher(const FST &fst, MatchType match_type, 439 MatcherData *data = 0, S *s = 0) 440 : matcher_(fst, match_type), 441 lfst_(0), 442 label_reachable_(0), 443 s_(kNoStateId), 444 error_(false) { 445 if 
(!(F & (kInputLookAheadMatcher | kOutputLookAheadMatcher))) { 446 FSTERROR() << "LabelLookaheadMatcher: bad matcher flags: " << F; 447 error_ = true; 448 } 449 bool reach_input = match_type == MATCH_INPUT; 450 if (data) { 451 if (reach_input == data->ReachInput()) 452 label_reachable_ = new LabelReachable<Arc, S>(data, s); 453 } else if ((reach_input && (F & kInputLookAheadMatcher)) || 454 (!reach_input && (F & kOutputLookAheadMatcher))) { 455 label_reachable_ = new LabelReachable<Arc, S>( 456 fst, reach_input, s, F & kLookAheadKeepRelabelData); 457 } 458 } 459 460 LabelLookAheadMatcher(const LabelLookAheadMatcher<M, F, S> &lmatcher, 461 bool safe = false) 462 : matcher_(lmatcher.matcher_, safe), 463 lfst_(lmatcher.lfst_), 464 label_reachable_( 465 lmatcher.label_reachable_ ? 466 new LabelReachable<Arc, S>(*lmatcher.label_reachable_) : 0), 467 s_(kNoStateId), 468 error_(lmatcher.error_) {} 469 470 ~LabelLookAheadMatcher() { 471 delete label_reachable_; 472 } 473 474 // General matcher methods 475 LabelLookAheadMatcher<M, F, S> *Copy(bool safe = false) const { 476 return new LabelLookAheadMatcher<M, F, S>(*this, safe); 477 } 478 479 MatchType Type(bool test) const { return matcher_.Type(test); } 480 481 void SetState(StateId s) { 482 if (s_ == s) 483 return; 484 s_ = s; 485 match_set_state_ = false; 486 reach_set_state_ = false; 487 } 488 489 bool Find(Label label) { 490 if (!match_set_state_) { 491 matcher_.SetState(s_); 492 match_set_state_ = true; 493 } 494 return matcher_.Find(label); 495 } 496 497 bool Done() const { return matcher_.Done(); } 498 const Arc& Value() const { return matcher_.Value(); } 499 void Next() { matcher_.Next(); } 500 const FST &GetFst() const { return matcher_.GetFst(); } 501 502 uint64 Properties(uint64 inprops) const { 503 uint64 outprops = matcher_.Properties(inprops); 504 if (error_ || (label_reachable_ && label_reachable_->Error())) 505 outprops |= kError; 506 return outprops; 507 } 508 509 uint32 Flags() const { 510 if 
(label_reachable_ && label_reachable_->GetData()->ReachInput()) 511 return matcher_.Flags() | F | kInputLookAheadMatcher; 512 else if (label_reachable_ && !label_reachable_->GetData()->ReachInput()) 513 return matcher_.Flags() | F | kOutputLookAheadMatcher; 514 else 515 return matcher_.Flags(); 516 } 517 518 // Writable matcher methods 519 MatcherData *GetData() const { 520 return label_reachable_ ? label_reachable_->GetData() : 0; 521 }; 522 523 // Look-ahead methods. 524 bool LookAheadLabel(Label label) const { 525 if (label == 0) 526 return true; 527 528 if (label_reachable_) { 529 if (!reach_set_state_) { 530 label_reachable_->SetState(s_); 531 reach_set_state_ = true; 532 } 533 return label_reachable_->Reach(label); 534 } else { 535 return true; 536 } 537 } 538 539 // Checks if there is a matching (possibly super-final) transition 540 // at (s_, s). 541 template <class L> 542 bool LookAheadFst(const L &fst, StateId s); 543 544 void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) { 545 lfst_ = &fst; 546 if (label_reachable_) 547 label_reachable_->ReachInit(fst, copy); 548 } 549 550 template <class L> 551 void InitLookAheadFst(const L& fst, bool copy = false) { 552 lfst_ = static_cast<const Fst<Arc> *>(&fst); 553 if (label_reachable_) 554 label_reachable_->ReachInit(fst, copy); 555 } 556 557 private: 558 // This allows base class virtual access to non-virtual derived- 559 // class members of the same name. It makes the derived class more 560 // efficient to use but unsafe to further derive. 
561 virtual void SetState_(StateId s) { SetState(s); } 562 virtual bool Find_(Label label) { return Find(label); } 563 virtual bool Done_() const { return Done(); } 564 virtual const Arc& Value_() const { return Value(); } 565 virtual void Next_() { Next(); } 566 567 bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); } 568 bool LookAheadFst_(const Fst<Arc> &fst, StateId s) { 569 return LookAheadFst(fst, s); 570 } 571 572 mutable M matcher_; 573 const Fst<Arc> *lfst_; // Look-ahead FST 574 LabelReachable<Arc, S> *label_reachable_; // Label reachability info 575 StateId s_; // Matcher state 576 bool match_set_state_; // matcher_.SetState called? 577 mutable bool reach_set_state_; // reachable_.SetState called? 578 bool error_; 579 }; 580 581 template <class M, uint32 F, class S> 582 template <class L> inline 583 bool LabelLookAheadMatcher<M, F, S>::LookAheadFst(const L &fst, StateId s) { 584 if (static_cast<const Fst<Arc> *>(&fst) != lfst_) 585 InitLookAheadFst(fst); 586 587 SetLookAheadWeight(Weight::One()); 588 ClearLookAheadPrefix(); 589 590 if (!label_reachable_) 591 return true; 592 593 label_reachable_->SetState(s_, s); 594 reach_set_state_ = true; 595 596 bool compute_weight = F & kLookAheadWeight; 597 bool compute_prefix = F & kLookAheadPrefix; 598 599 bool reach_input = Type(false) == MATCH_OUTPUT; 600 ArcIterator<L> aiter(fst, s); 601 bool reach_arc = label_reachable_->Reach(&aiter, 0, 602 internal::NumArcs(*lfst_, s), 603 reach_input, compute_weight); 604 Weight lfinal = internal::Final(*lfst_, s); 605 bool reach_final = lfinal != Weight::Zero() && label_reachable_->ReachFinal(); 606 if (reach_arc) { 607 ssize_t begin = label_reachable_->ReachBegin(); 608 ssize_t end = label_reachable_->ReachEnd(); 609 if (compute_prefix && end - begin == 1 && !reach_final) { 610 aiter.Seek(begin); 611 SetLookAheadPrefix(aiter.Value()); 612 compute_weight = false; 613 } else if (compute_weight) { 614 SetLookAheadWeight(label_reachable_->ReachWeight()); 615 } 
616 } 617 if (reach_final && compute_weight) 618 SetLookAheadWeight(reach_arc ? 619 Plus(LookAheadWeight(), lfinal) : lfinal); 620 621 return reach_arc || reach_final; 622 } 623 624 625 // Label-lookahead relabeling class. 626 template <class A> 627 class LabelLookAheadRelabeler { 628 public: 629 typedef typename A::Label Label; 630 typedef LabelReachableData<Label> MatcherData; 631 typedef AddOnPair<MatcherData, MatcherData> D; 632 633 // Relabels matcher Fst - initialization function object. 634 template <typename I> 635 LabelLookAheadRelabeler(I **impl); 636 637 // Relabels arbitrary Fst. Class L should be a label-lookahead Fst. 638 template <class L> 639 static void Relabel(MutableFst<A> *fst, const L &mfst, 640 bool relabel_input) { 641 typename L::Impl *impl = mfst.GetImpl(); 642 D *data = impl->GetAddOn(); 643 LabelReachable<A> reachable(data->First() ? 644 data->First() : data->Second()); 645 reachable.Relabel(fst, relabel_input); 646 } 647 648 // Returns relabeling pairs (cf. relabel.h::Relabel()). 649 // Class L should be a label-lookahead Fst. 650 // If 'avoid_collisions' is true, extra pairs are added to 651 // ensure no collisions when relabeling automata that have 652 // labels unseen here. 653 template <class L> 654 static void RelabelPairs(const L &mfst, vector<pair<Label, Label> > *pairs, 655 bool avoid_collisions = false) { 656 typename L::Impl *impl = mfst.GetImpl(); 657 D *data = impl->GetAddOn(); 658 LabelReachable<A> reachable(data->First() ? 
659 data->First() : data->Second()); 660 reachable.RelabelPairs(pairs, avoid_collisions); 661 } 662 }; 663 664 template <class A> 665 template <typename I> inline 666 LabelLookAheadRelabeler<A>::LabelLookAheadRelabeler(I **impl) { 667 Fst<A> &fst = (*impl)->GetFst(); 668 D *data = (*impl)->GetAddOn(); 669 const string name = (*impl)->Type(); 670 bool is_mutable = fst.Properties(kMutable, false); 671 MutableFst<A> *mfst = 0; 672 if (is_mutable) { 673 mfst = static_cast<MutableFst<A> *>(&fst); 674 } else { 675 mfst = new VectorFst<A>(fst); 676 data->IncrRefCount(); 677 delete *impl; 678 } 679 if (data->First()) { // reach_input 680 LabelReachable<A> reachable(data->First()); 681 reachable.Relabel(mfst, true); 682 if (!FLAGS_save_relabel_ipairs.empty()) { 683 vector<pair<Label, Label> > pairs; 684 reachable.RelabelPairs(&pairs, true); 685 WriteLabelPairs(FLAGS_save_relabel_ipairs, pairs); 686 } 687 } else { 688 LabelReachable<A> reachable(data->Second()); 689 reachable.Relabel(mfst, false); 690 if (!FLAGS_save_relabel_opairs.empty()) { 691 vector<pair<Label, Label> > pairs; 692 reachable.RelabelPairs(&pairs, true); 693 WriteLabelPairs(FLAGS_save_relabel_opairs, pairs); 694 } 695 } 696 if (!is_mutable) { 697 *impl = new I(*mfst, name); 698 (*impl)->SetAddOn(data); 699 delete mfst; 700 data->DecrRefCount(); 701 } 702 } 703 704 705 // Generic lookahead matcher, templated on the FST definition 706 // - a wrapper around pointer to specific one. 
707 template <class F> 708 class LookAheadMatcher { 709 public: 710 typedef F FST; 711 typedef typename F::Arc Arc; 712 typedef typename Arc::StateId StateId; 713 typedef typename Arc::Label Label; 714 typedef typename Arc::Weight Weight; 715 typedef LookAheadMatcherBase<Arc> LBase; 716 717 LookAheadMatcher(const F &fst, MatchType match_type) { 718 base_ = fst.InitMatcher(match_type); 719 if (!base_) 720 base_ = new SortedMatcher<F>(fst, match_type); 721 lookahead_ = false; 722 } 723 724 LookAheadMatcher(const LookAheadMatcher<F> &matcher, bool safe = false) { 725 base_ = matcher.base_->Copy(safe); 726 lookahead_ = matcher.lookahead_; 727 } 728 729 ~LookAheadMatcher() { delete base_; } 730 731 // General matcher methods 732 LookAheadMatcher<F> *Copy(bool safe = false) const { 733 return new LookAheadMatcher<F>(*this, safe); 734 } 735 736 MatchType Type(bool test) const { return base_->Type(test); } 737 void SetState(StateId s) { base_->SetState(s); } 738 bool Find(Label label) { return base_->Find(label); } 739 bool Done() const { return base_->Done(); } 740 const Arc& Value() const { return base_->Value(); } 741 void Next() { base_->Next(); } 742 const F &GetFst() const { return static_cast<const F &>(base_->GetFst()); } 743 744 uint64 Properties(uint64 props) const { return base_->Properties(props); } 745 746 uint32 Flags() const { return base_->Flags(); } 747 748 // Look-ahead methods 749 bool LookAheadLabel(Label label) const { 750 if (LookAheadCheck()) { 751 LBase *lbase = static_cast<LBase *>(base_); 752 return lbase->LookAheadLabel(label); 753 } else { 754 return true; 755 } 756 } 757 758 bool LookAheadFst(const Fst<Arc> &fst, StateId s) { 759 if (LookAheadCheck()) { 760 LBase *lbase = static_cast<LBase *>(base_); 761 return lbase->LookAheadFst(fst, s); 762 } else { 763 return true; 764 } 765 } 766 767 Weight LookAheadWeight() const { 768 if (LookAheadCheck()) { 769 LBase *lbase = static_cast<LBase *>(base_); 770 return lbase->LookAheadWeight(); 771 } else { 
772 return Weight::One(); 773 } 774 } 775 776 bool LookAheadPrefix(Arc *arc) const { 777 if (LookAheadCheck()) { 778 LBase *lbase = static_cast<LBase *>(base_); 779 return lbase->LookAheadPrefix(arc); 780 } else { 781 return false; 782 } 783 } 784 785 void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) { 786 if (LookAheadCheck()) { 787 LBase *lbase = static_cast<LBase *>(base_); 788 lbase->InitLookAheadFst(fst, copy); 789 } 790 } 791 792 private: 793 bool LookAheadCheck() const { 794 if (!lookahead_) { 795 lookahead_ = base_->Flags() & 796 (kInputLookAheadMatcher | kOutputLookAheadMatcher); 797 if (!lookahead_) { 798 FSTERROR() << "LookAheadMatcher: No look-ahead matcher defined"; 799 } 800 } 801 return lookahead_; 802 } 803 804 MatcherBase<Arc> *base_; 805 mutable bool lookahead_; 806 807 void operator=(const LookAheadMatcher<Arc> &); // disallow 808 }; 809 810 } // namespace fst 811 812 #endif // FST_LIB_LOOKAHEAD_MATCHER_H__ 813