Home | History | Annotate | Download | only in pdt
      1 // compose.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // \file
     19 // Compose a PDT and an FST.
     20 
     21 #ifndef FST_EXTENSIONS_PDT_COMPOSE_H__
     22 #define FST_EXTENSIONS_PDT_COMPOSE_H__
     23 
     24 #include <list>
     25 
     26 #include <fst/extensions/pdt/pdt.h>
     27 #include <fst/compose.h>
     28 
     29 namespace fst {
     30 
     31 // Return paren arcs for Find(kNoLabel).
     32 const uint32 kParenList =  0x00000001;
     33 
     34 // Return a kNolabel loop for Find(paren).
     35 const uint32 kParenLoop =  0x00000002;
     36 
     37 // This class is a matcher that treats parens as multi-epsilon labels.
     38 // It is most efficient if the parens are in a range non-overlapping with
     39 // the non-paren labels.
     40 template <class F>
     41 class ParenMatcher {
     42  public:
     43   typedef SortedMatcher<F> M;
     44   typedef typename M::FST FST;
     45   typedef typename M::Arc Arc;
     46   typedef typename Arc::StateId StateId;
     47   typedef typename Arc::Label Label;
     48   typedef typename Arc::Weight Weight;
     49 
     50   ParenMatcher(const FST &fst, MatchType match_type,
     51                uint32 flags = (kParenLoop | kParenList))
     52       : matcher_(fst, match_type),
     53         match_type_(match_type),
     54         flags_(flags) {
     55     if (match_type == MATCH_INPUT) {
     56       loop_.ilabel = kNoLabel;
     57       loop_.olabel = 0;
     58     } else {
     59       loop_.ilabel = 0;
     60       loop_.olabel = kNoLabel;
     61     }
     62     loop_.weight = Weight::One();
     63     loop_.nextstate = kNoStateId;
     64   }
     65 
     66   ParenMatcher(const ParenMatcher<F> &matcher, bool safe = false)
     67       : matcher_(matcher.matcher_, safe),
     68         match_type_(matcher.match_type_),
     69         flags_(matcher.flags_),
     70         open_parens_(matcher.open_parens_),
     71         close_parens_(matcher.close_parens_),
     72         loop_(matcher.loop_) {
     73     loop_.nextstate = kNoStateId;
     74   }
     75 
     76   ParenMatcher<F> *Copy(bool safe = false) const {
     77     return new ParenMatcher<F>(*this, safe);
     78   }
     79 
     80   MatchType Type(bool test) const { return matcher_.Type(test); }
     81 
     82   void SetState(StateId s) {
     83     matcher_.SetState(s);
     84     loop_.nextstate = s;
     85   }
     86 
     87   bool Find(Label match_label);
     88 
     89   bool Done() const {
     90     return done_;
     91   }
     92 
     93   const Arc& Value() const {
     94     return paren_loop_ ? loop_ : matcher_.Value();
     95   }
     96 
     97   void Next();
     98 
     99   const FST &GetFst() const { return matcher_.GetFst(); }
    100 
    101   uint64 Properties(uint64 props) const { return matcher_.Properties(props); }
    102 
    103   uint32 Flags() const { return matcher_.Flags(); }
    104 
    105   void AddOpenParen(Label label) {
    106     if (label == 0) {
    107       FSTERROR() << "ParenMatcher: Bad open paren label: 0";
    108     } else {
    109       open_parens_.Insert(label);
    110     }
    111   }
    112 
    113   void AddCloseParen(Label label) {
    114     if (label == 0) {
    115       FSTERROR() << "ParenMatcher: Bad close paren label: 0";
    116     } else {
    117       close_parens_.Insert(label);
    118     }
    119   }
    120 
    121   void RemoveOpenParen(Label label) {
    122     if (label == 0) {
    123       FSTERROR() << "ParenMatcher: Bad open paren label: 0";
    124     } else {
    125       open_parens_.Erase(label);
    126     }
    127   }
    128 
    129   void RemoveCloseParen(Label label) {
    130     if (label == 0) {
    131       FSTERROR() << "ParenMatcher: Bad close paren label: 0";
    132     } else {
    133       close_parens_.Erase(label);
    134     }
    135   }
    136 
    137   void ClearOpenParens() {
    138     open_parens_.Clear();
    139   }
    140 
    141   void ClearCloseParens() {
    142     close_parens_.Clear();
    143   }
    144 
    145   bool IsOpenParen(Label label) const {
    146     return open_parens_.Member(label);
    147   }
    148 
    149   bool IsCloseParen(Label label) const {
    150     return close_parens_.Member(label);
    151   }
    152 
    153  private:
    154   // Advances matcher to next open paren if it exists, returning true.
    155   // O.w. returns false.
    156   bool NextOpenParen();
    157 
    158   // Advances matcher to next open paren if it exists, returning true.
    159   // O.w. returns false.
    160   bool NextCloseParen();
    161 
    162   M matcher_;
    163   MatchType match_type_;          // Type of match to perform
    164   uint32 flags_;
    165 
    166   // open paren label set
    167   CompactSet<Label, kNoLabel> open_parens_;
    168 
    169   // close paren label set
    170   CompactSet<Label, kNoLabel> close_parens_;
    171 
    172 
    173   bool open_paren_list_;         // Matching open paren list
    174   bool close_paren_list_;        // Matching close paren list
    175   bool paren_loop_;              // Current arc is the implicit paren loop
    176   mutable Arc loop_;             // For non-consuming symbols
    177   bool done_;                    // Matching done
    178 
    179   void operator=(const ParenMatcher<F> &);  // Disallow
    180 };
    181 
    182 template <class M> inline
    183 bool ParenMatcher<M>::Find(Label match_label) {
    184   open_paren_list_ = false;
    185   close_paren_list_ = false;
    186   paren_loop_ = false;
    187   done_ = false;
    188 
    189   // Returns all parenthesis arcs
    190   if (match_label == kNoLabel && (flags_ & kParenList)) {
    191     if (open_parens_.LowerBound() != kNoLabel) {
    192       matcher_.LowerBound(open_parens_.LowerBound());
    193       open_paren_list_ = NextOpenParen();
    194       if (open_paren_list_) return true;
    195     }
    196     if (close_parens_.LowerBound() != kNoLabel) {
    197       matcher_.LowerBound(close_parens_.LowerBound());
    198       close_paren_list_ = NextCloseParen();
    199       if (close_paren_list_) return true;
    200     }
    201   }
    202 
    203   // Returns 'implicit' paren loop
    204   if (match_label > 0 && (flags_ & kParenLoop) &&
    205       (IsOpenParen(match_label) || IsCloseParen(match_label))) {
    206     paren_loop_ = true;
    207     return true;
    208   }
    209 
    210   // Returns all other labels
    211   if (matcher_.Find(match_label))
    212     return true;
    213 
    214   done_ = true;
    215   return false;
    216 }
    217 
    218 template <class F> inline
    219 void ParenMatcher<F>::Next() {
    220   if (paren_loop_) {
    221     paren_loop_ = false;
    222     done_ = true;
    223   } else if (open_paren_list_) {
    224     matcher_.Next();
    225     open_paren_list_ = NextOpenParen();
    226     if (open_paren_list_) return;
    227 
    228     if (close_parens_.LowerBound() != kNoLabel) {
    229       matcher_.LowerBound(close_parens_.LowerBound());
    230       close_paren_list_ = NextCloseParen();
    231       if (close_paren_list_) return;
    232     }
    233     done_ = !matcher_.Find(kNoLabel);
    234   } else if (close_paren_list_) {
    235     matcher_.Next();
    236     close_paren_list_ = NextCloseParen();
    237     if (close_paren_list_) return;
    238     done_ = !matcher_.Find(kNoLabel);
    239   } else {
    240     matcher_.Next();
    241     done_ = matcher_.Done();
    242   }
    243 }
    244 
    245 // Advances matcher to next open paren if it exists, returning true.
    246 // O.w. returns false.
    247 template <class F> inline
    248 bool ParenMatcher<F>::NextOpenParen() {
    249   for (; !matcher_.Done(); matcher_.Next()) {
    250     Label label = match_type_ == MATCH_INPUT ?
    251         matcher_.Value().ilabel : matcher_.Value().olabel;
    252     if (label > open_parens_.UpperBound())
    253       return false;
    254     if (IsOpenParen(label))
    255       return true;
    256   }
    257   return false;
    258 }
    259 
    260 // Advances matcher to next close paren if it exists, returning true.
    261 // O.w. returns false.
    262 template <class F> inline
    263 bool ParenMatcher<F>::NextCloseParen() {
    264   for (; !matcher_.Done(); matcher_.Next()) {
    265     Label label = match_type_ == MATCH_INPUT ?
    266         matcher_.Value().ilabel : matcher_.Value().olabel;
    267     if (label > close_parens_.UpperBound())
    268       return false;
    269     if (IsCloseParen(label))
    270       return true;
    271   }
    272   return false;
    273 }
    274 
    275 
    276 template <class F>
    277 class ParenFilter {
    278  public:
    279   typedef typename F::FST1 FST1;
    280   typedef typename F::FST2 FST2;
    281   typedef typename F::Arc Arc;
    282   typedef typename Arc::StateId StateId;
    283   typedef typename Arc::Label Label;
    284   typedef typename Arc::Weight Weight;
    285   typedef typename F::Matcher1 Matcher1;
    286   typedef typename F::Matcher2 Matcher2;
    287   typedef typename F::FilterState FilterState1;
    288   typedef StateId StackId;
    289   typedef PdtStack<StackId, Label> ParenStack;
    290   typedef IntegerFilterState<StackId> FilterState2;
    291   typedef PairFilterState<FilterState1, FilterState2> FilterState;
    292   typedef ParenFilter<F> Filter;
    293 
    294   ParenFilter(const FST1 &fst1, const FST2 &fst2,
    295               Matcher1 *matcher1 = 0,  Matcher2 *matcher2 = 0,
    296               const vector<pair<Label, Label> > *parens = 0,
    297               bool expand = false, bool keep_parens = true)
    298       : filter_(fst1, fst2, matcher1, matcher2),
    299         parens_(parens ? *parens : vector<pair<Label, Label> >()),
    300         expand_(expand),
    301         keep_parens_(keep_parens),
    302         f_(FilterState::NoState()),
    303         stack_(parens_),
    304         paren_id_(-1) {
    305     if (parens) {
    306       for (size_t i = 0; i < parens->size(); ++i) {
    307         const pair<Label, Label>  &p = (*parens)[i];
    308         parens_.push_back(p);
    309         GetMatcher1()->AddOpenParen(p.first);
    310         GetMatcher2()->AddOpenParen(p.first);
    311         if (!expand_) {
    312           GetMatcher1()->AddCloseParen(p.second);
    313           GetMatcher2()->AddCloseParen(p.second);
    314         }
    315       }
    316     }
    317   }
    318 
    319   ParenFilter(const Filter &filter, bool safe = false)
    320       : filter_(filter.filter_, safe),
    321         parens_(filter.parens_),
    322         expand_(filter.expand_),
    323         keep_parens_(filter.keep_parens_),
    324         f_(FilterState::NoState()),
    325         stack_(filter.parens_),
    326         paren_id_(-1) { }
    327 
    328   FilterState Start() const {
    329     return FilterState(filter_.Start(), FilterState2(0));
    330   }
    331 
    332   void SetState(StateId s1, StateId s2, const FilterState &f) {
    333     f_ = f;
    334     filter_.SetState(s1, s2, f_.GetState1());
    335     if (!expand_)
    336       return;
    337 
    338     ssize_t paren_id = stack_.Top(f.GetState2().GetState());
    339     if (paren_id != paren_id_) {
    340       if (paren_id_ != -1) {
    341         GetMatcher1()->RemoveCloseParen(parens_[paren_id_].second);
    342         GetMatcher2()->RemoveCloseParen(parens_[paren_id_].second);
    343       }
    344       paren_id_ = paren_id;
    345       if (paren_id_ != -1) {
    346         GetMatcher1()->AddCloseParen(parens_[paren_id_].second);
    347         GetMatcher2()->AddCloseParen(parens_[paren_id_].second);
    348       }
    349     }
    350   }
    351 
    352   FilterState FilterArc(Arc *arc1, Arc *arc2) const {
    353     FilterState1 f1 = filter_.FilterArc(arc1, arc2);
    354     const FilterState2 &f2 = f_.GetState2();
    355     if (f1 == FilterState1::NoState())
    356       return FilterState::NoState();
    357 
    358     if (arc1->olabel == kNoLabel && arc2->ilabel) {         // arc2 parentheses
    359       if (keep_parens_) {
    360         arc1->ilabel = arc2->ilabel;
    361       } else if (arc2->ilabel) {
    362         arc2->olabel = arc1->ilabel;
    363       }
    364       return FilterParen(arc2->ilabel, f1, f2);
    365     } else if (arc2->ilabel == kNoLabel && arc1->olabel) {  // arc1 parentheses
    366       if (keep_parens_) {
    367         arc2->olabel = arc1->olabel;
    368       } else {
    369         arc1->ilabel = arc2->olabel;
    370       }
    371       return FilterParen(arc1->olabel, f1, f2);
    372     } else {
    373       return FilterState(f1, f2);
    374     }
    375   }
    376 
    377   void FilterFinal(Weight *w1, Weight *w2) const {
    378     if (f_.GetState2().GetState() != 0)
    379       *w1 = Weight::Zero();
    380     filter_.FilterFinal(w1, w2);
    381   }
    382 
    383   // Return resp matchers. Ownership stays with filter.
    384   Matcher1 *GetMatcher1() { return filter_.GetMatcher1(); }
    385   Matcher2 *GetMatcher2() { return filter_.GetMatcher2(); }
    386 
    387   uint64 Properties(uint64 iprops) const {
    388     uint64 oprops = filter_.Properties(iprops);
    389     return oprops & kILabelInvariantProperties & kOLabelInvariantProperties;
    390   }
    391 
    392  private:
    393   const FilterState FilterParen(Label label, const FilterState1 &f1,
    394                                 const FilterState2 &f2) const {
    395     if (!expand_)
    396       return FilterState(f1, f2);
    397 
    398     StackId stack_id = stack_.Find(f2.GetState(), label);
    399     if (stack_id < 0) {
    400       return FilterState::NoState();
    401     } else {
    402       return FilterState(f1, FilterState2(stack_id));
    403     }
    404   }
    405 
    406   F filter_;
    407   vector<pair<Label, Label> > parens_;
    408   bool expand_;                    // Expands to FST
    409   bool keep_parens_;               // Retains parentheses in output
    410   FilterState f_;                  // Current filter state
    411   mutable ParenStack stack_;
    412   ssize_t paren_id_;
    413 };
    414 
    415 // Class to setup composition options for PDT composition.
    416 // Default is for the PDT as the first composition argument.
    417 template <class Arc, bool left_pdt = true>
    418 class PdtComposeFstOptions : public
    419 ComposeFstOptions<Arc,
    420                   ParenMatcher< Fst<Arc> >,
    421                   ParenFilter<AltSequenceComposeFilter<
    422                                 ParenMatcher< Fst<Arc> > > > > {
    423  public:
    424   typedef typename Arc::Label Label;
    425   typedef ParenMatcher< Fst<Arc> > PdtMatcher;
    426   typedef ParenFilter<AltSequenceComposeFilter<PdtMatcher> > PdtFilter;
    427   typedef ComposeFstOptions<Arc, PdtMatcher, PdtFilter> COptions;
    428   using COptions::matcher1;
    429   using COptions::matcher2;
    430   using COptions::filter;
    431 
    432   PdtComposeFstOptions(const Fst<Arc> &ifst1,
    433                     const vector<pair<Label, Label> > &parens,
    434                        const Fst<Arc> &ifst2, bool expand = false,
    435                        bool keep_parens = true) {
    436     matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenList);
    437     matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenLoop);
    438 
    439     filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens,
    440                            expand, keep_parens);
    441   }
    442 };
    443 
    444 // Class to setup composition options for PDT with FST composition.
    445 // Specialization is for the FST as the first composition argument.
    446 template <class Arc>
    447 class PdtComposeFstOptions<Arc, false> : public
    448 ComposeFstOptions<Arc,
    449                   ParenMatcher< Fst<Arc> >,
    450                   ParenFilter<SequenceComposeFilter<
    451                                 ParenMatcher< Fst<Arc> > > > > {
    452  public:
    453   typedef typename Arc::Label Label;
    454   typedef ParenMatcher< Fst<Arc> > PdtMatcher;
    455   typedef ParenFilter<SequenceComposeFilter<PdtMatcher> > PdtFilter;
    456   typedef ComposeFstOptions<Arc, PdtMatcher, PdtFilter> COptions;
    457   using COptions::matcher1;
    458   using COptions::matcher2;
    459   using COptions::filter;
    460 
    461   PdtComposeFstOptions(const Fst<Arc> &ifst1,
    462                        const Fst<Arc> &ifst2,
    463                        const vector<pair<Label, Label> > &parens,
    464                        bool expand = false, bool keep_parens = true) {
    465     matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenLoop);
    466     matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenList);
    467 
    468     filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens,
    469                            expand, keep_parens);
    470   }
    471 };
    472 
    473 enum PdtComposeFilter {
    474   PAREN_FILTER,          // Bar-Hillel construction; keeps parentheses
    475   EXPAND_FILTER,         // Bar-Hillel + expansion; removes parentheses
    476   EXPAND_PAREN_FILTER,   // Bar-Hillel + expansion; keeps parentheses
    477 };
    478 
    479 struct PdtComposeOptions {
    480   bool connect;  // Connect output
    481   PdtComposeFilter filter_type;  // Which pre-defined filter to use
    482 
    483   explicit PdtComposeOptions(bool c, PdtComposeFilter ft = PAREN_FILTER)
    484       : connect(c), filter_type(ft) {}
    485   PdtComposeOptions() : connect(true), filter_type(PAREN_FILTER) {}
    486 };
    487 
    488 // Composes pushdown transducer (PDT) encoded as an FST (1st arg) and
    489 // an FST (2nd arg) with the result also a PDT encoded as an Fst. (3rd arg).
    490 // In the PDTs, some transitions are labeled with open or close
    491 // parentheses. To be interpreted as a PDT, the parens must balance on
    492 // a path (see PdtExpand()). The open-close parenthesis label pairs
    493 // are passed in 'parens'.
    494 template <class Arc>
    495 void Compose(const Fst<Arc> &ifst1,
    496              const vector<pair<typename Arc::Label,
    497                                typename Arc::Label> > &parens,
    498              const Fst<Arc> &ifst2,
    499              MutableFst<Arc> *ofst,
    500              const PdtComposeOptions &opts = PdtComposeOptions()) {
    501   bool expand = opts.filter_type != PAREN_FILTER;
    502   bool keep_parens = opts.filter_type != EXPAND_FILTER;
    503   PdtComposeFstOptions<Arc, true> copts(ifst1, parens, ifst2,
    504                                         expand, keep_parens);
    505   copts.gc_limit = 0;
    506   *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
    507   if (opts.connect)
    508     Connect(ofst);
    509 }
    510 
    511 // Composes an FST (1st arg) and pushdown transducer (PDT) encoded as
    512 // an FST (2nd arg) with the result also a PDT encoded as an Fst (3rd arg).
    513 // In the PDTs, some transitions are labeled with open or close
    514 // parentheses. To be interpreted as a PDT, the parens must balance on
    515 // a path (see ExpandFst()). The open-close parenthesis label pairs
    516 // are passed in 'parens'.
    517 template <class Arc>
    518 void Compose(const Fst<Arc> &ifst1,
    519              const Fst<Arc> &ifst2,
    520              const vector<pair<typename Arc::Label,
    521                                typename Arc::Label> > &parens,
    522              MutableFst<Arc> *ofst,
    523              const PdtComposeOptions &opts = PdtComposeOptions()) {
    524   bool expand = opts.filter_type != PAREN_FILTER;
    525   bool keep_parens = opts.filter_type != EXPAND_FILTER;
    526   PdtComposeFstOptions<Arc, false> copts(ifst1, ifst2, parens,
    527                                          expand, keep_parens);
    528   copts.gc_limit = 0;
    529   *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
    530   if (opts.connect)
    531     Connect(ofst);
    532 }
    533 
    534 }  // namespace fst
    535 
    536 #endif  // FST_EXTENSIONS_PDT_COMPOSE_H__
    537