Home | History | Annotate | Download | only in pdt
      1 
      2 // Licensed under the Apache License, Version 2.0 (the "License");
      3 // you may not use this file except in compliance with the License.
      4 // You may obtain a copy of the License at
      5 //
      6 //     http://www.apache.org/licenses/LICENSE-2.0
      7 //
      8 // Unless required by applicable law or agreed to in writing, software
      9 // distributed under the License is distributed on an "AS IS" BASIS,
     10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     11 // See the License for the specific language governing permissions and
     12 // limitations under the License.
     13 //
     14 // Copyright 2005-2010 Google, Inc.
     15 // Author: jpr (at) google.com (Jake Ratkiewicz)
     16 // Convenience file for including all PDT operations at once, and/or
     17 // registering them for new arc types.
     18 
     19 #ifndef FST_EXTENSIONS_PDT_PDTSCRIPT_H_
     20 #define FST_EXTENSIONS_PDT_PDTSCRIPT_H_
     21 
     22 #include <utility>
     23 using std::pair; using std::make_pair;
     24 #include <vector>
     25 using std::vector;
     26 
     27 #include <fst/compose.h>  // for ComposeOptions
     28 #include <fst/util.h>
     29 
     30 #include <fst/script/fst-class.h>
     31 #include <fst/script/arg-packs.h>
     32 #include <fst/script/shortest-path.h>
     33 
     34 #include <fst/extensions/pdt/compose.h>
     35 #include <fst/extensions/pdt/expand.h>
     36 #include <fst/extensions/pdt/info.h>
     37 #include <fst/extensions/pdt/replace.h>
     38 #include <fst/extensions/pdt/reverse.h>
     39 #include <fst/extensions/pdt/shortest-path.h>
     40 
     41 
     42 namespace fst {
     43 namespace script {
     44 
     45 // PDT COMPOSE
     46 
     47 typedef args::Package<const FstClass &,
     48                       const FstClass &,
     49                       const vector<pair<int64, int64> >&,
     50                       MutableFstClass *,
     51                       const PdtComposeOptions &,
     52                       bool> PdtComposeArgs;
     53 
     54 template<class Arc>
     55 void PdtCompose(PdtComposeArgs *args) {
     56   const Fst<Arc> &ifst1 = *(args->arg1.GetFst<Arc>());
     57   const Fst<Arc> &ifst2 = *(args->arg2.GetFst<Arc>());
     58   MutableFst<Arc> *ofst = args->arg4->GetMutableFst<Arc>();
     59 
     60   vector<pair<typename Arc::Label, typename Arc::Label> > parens(
     61       args->arg3.size());
     62 
     63   for (size_t i = 0; i < parens.size(); ++i) {
     64     parens[i].first = args->arg3[i].first;
     65     parens[i].second = args->arg3[i].second;
     66   }
     67 
     68   if (args->arg6) {
     69     Compose(ifst1, parens, ifst2, ofst, args->arg5);
     70   } else {
     71     Compose(ifst1, ifst2, parens, ofst, args->arg5);
     72   }
     73 }
     74 
     75 void PdtCompose(const FstClass & ifst1,
     76                 const FstClass & ifst2,
     77                 const vector<pair<int64, int64> > &parens,
     78                 MutableFstClass *ofst,
     79                 const PdtComposeOptions &copts,
     80                 bool left_pdt);
     81 
     82 // PDT EXPAND
     83 
     84 struct PdtExpandOptions {
     85   bool connect;
     86   bool keep_parentheses;
     87   WeightClass weight_threshold;
     88 
     89   PdtExpandOptions(bool c = true, bool k = false,
     90                    WeightClass w = WeightClass::Zero())
     91       : connect(c), keep_parentheses(k), weight_threshold(w) {}
     92 };
     93 
     94 typedef args::Package<const FstClass &,
     95                       const vector<pair<int64, int64> >&,
     96                       MutableFstClass *, PdtExpandOptions> PdtExpandArgs;
     97 
     98 template<class Arc>
     99 void PdtExpand(PdtExpandArgs *args) {
    100   const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
    101   MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>();
    102 
    103   vector<pair<typename Arc::Label, typename Arc::Label> > parens(
    104       args->arg2.size());
    105   for (size_t i = 0; i < parens.size(); ++i) {
    106     parens[i].first = args->arg2[i].first;
    107     parens[i].second = args->arg2[i].second;
    108   }
    109   Expand(fst, parens, ofst,
    110          ExpandOptions<Arc>(
    111              args->arg4.connect, args->arg4.keep_parentheses,
    112              *(args->arg4.weight_threshold.GetWeight<typename Arc::Weight>())));
    113 }
    114 
    115 void PdtExpand(const FstClass &ifst,
    116                const vector<pair<int64, int64> > &parens,
    117                MutableFstClass *ofst, const PdtExpandOptions &opts);
    118 
    119 void PdtExpand(const FstClass &ifst,
    120                const vector<pair<int64, int64> > &parens,
    121                MutableFstClass *ofst, bool connect);
    122 
    123 // PDT REPLACE
    124 
    125 typedef args::Package<const vector<pair<int64, const FstClass*> > &,
    126                       MutableFstClass *,
    127                       vector<pair<int64, int64> > *,
    128                       const int64 &> PdtReplaceArgs;
    129 template<class Arc>
    130 void PdtReplace(PdtReplaceArgs *args) {
    131   vector<pair<typename Arc::Label, const Fst<Arc> *> > tuples(
    132       args->arg1.size());
    133   for (size_t i = 0; i < tuples.size(); ++i) {
    134     tuples[i].first = args->arg1[i].first;
    135     tuples[i].second = (args->arg1[i].second)->GetFst<Arc>();
    136   }
    137   MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>();
    138   vector<pair<typename Arc::Label, typename Arc::Label> > parens(
    139       args->arg3->size());
    140 
    141   for (size_t i = 0; i < parens.size(); ++i) {
    142     parens[i].first = args->arg3->at(i).first;
    143     parens[i].second = args->arg3->at(i).second;
    144   }
    145   Replace(tuples, ofst, &parens, args->arg4);
    146 
    147   // now copy parens back
    148   args->arg3->resize(parens.size());
    149   for (size_t i = 0; i < parens.size(); ++i) {
    150     (*args->arg3)[i].first = parens[i].first;
    151     (*args->arg3)[i].second = parens[i].second;
    152   }
    153 }
    154 
    155 void PdtReplace(const vector<pair<int64, const FstClass*> > &fst_tuples,
    156                 MutableFstClass *ofst,
    157                 vector<pair<int64, int64> > *parens,
    158                 const int64 &root);
    159 
    160 // PDT REVERSE
    161 
    162 typedef args::Package<const FstClass &,
    163                       const vector<pair<int64, int64> >&,
    164                       MutableFstClass *> PdtReverseArgs;
    165 
    166 template<class Arc>
    167 void PdtReverse(PdtReverseArgs *args) {
    168   const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
    169   MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>();
    170 
    171   vector<pair<typename Arc::Label, typename Arc::Label> > parens(
    172       args->arg2.size());
    173   for (size_t i = 0; i < parens.size(); ++i) {
    174     parens[i].first = args->arg2[i].first;
    175     parens[i].second = args->arg2[i].second;
    176   }
    177   Reverse(fst, parens, ofst);
    178 }
    179 
    180 void PdtReverse(const FstClass &ifst,
    181                 const vector<pair<int64, int64> > &parens,
    182                 MutableFstClass *ofst);
    183 
    184 
    185 // PDT SHORTESTPATH
    186 
    187 struct PdtShortestPathOptions {
    188   QueueType queue_type;
    189   bool keep_parentheses;
    190   bool path_gc;
    191 
    192   PdtShortestPathOptions(QueueType qt = FIFO_QUEUE,
    193                          bool kp = false, bool gc = true)
    194       : queue_type(qt), keep_parentheses(kp), path_gc(gc) {}
    195 };
    196 
    197 typedef args::Package<const FstClass &,
    198                       const vector<pair<int64, int64> >&,
    199                       MutableFstClass *,
    200                       const PdtShortestPathOptions &> PdtShortestPathArgs;
    201 
    202 template<class Arc>
    203 void PdtShortestPath(PdtShortestPathArgs *args) {
    204   typedef typename Arc::StateId StateId;
    205   typedef typename Arc::Label Label;
    206   typedef typename Arc::Weight Weight;
    207 
    208   const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
    209   MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>();
    210   const PdtShortestPathOptions &opts = args->arg4;
    211 
    212 
    213   vector<pair<Label, Label> > parens(args->arg2.size());
    214   for (size_t i = 0; i < parens.size(); ++i) {
    215     parens[i].first = args->arg2[i].first;
    216     parens[i].second = args->arg2[i].second;
    217   }
    218 
    219   switch (opts.queue_type) {
    220     default:
    221       FSTERROR() << "Unknown queue type: " << opts.queue_type;
    222     case FIFO_QUEUE: {
    223       typedef FifoQueue<StateId> Queue;
    224       fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses,
    225                                                          opts.path_gc);
    226       ShortestPath(fst, parens, ofst, spopts);
    227       return;
    228     }
    229     case LIFO_QUEUE: {
    230       typedef LifoQueue<StateId> Queue;
    231       fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses,
    232                                                          opts.path_gc);
    233       ShortestPath(fst, parens, ofst, spopts);
    234       return;
    235     }
    236     case STATE_ORDER_QUEUE: {
    237       typedef StateOrderQueue<StateId> Queue;
    238       fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses,
    239                                                          opts.path_gc);
    240       ShortestPath(fst, parens, ofst, spopts);
    241       return;
    242     }
    243   }
    244 }
    245 
    246 void PdtShortestPath(const FstClass &ifst,
    247                      const vector<pair<int64, int64> > &parens,
    248                      MutableFstClass *ofst,
    249                      const PdtShortestPathOptions &opts =
    250                      PdtShortestPathOptions());
    251 
    252 // PRINT INFO
    253 
    254 typedef args::Package<const FstClass &,
    255                       const vector<pair<int64, int64> > &> PrintPdtInfoArgs;
    256 
    257 template<class Arc>
    258 void PrintPdtInfo(PrintPdtInfoArgs *args) {
    259   const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
    260   vector<pair<typename Arc::Label, typename Arc::Label> > parens(
    261       args->arg2.size());
    262   for (size_t i = 0; i < parens.size(); ++i) {
    263     parens[i].first = args->arg2[i].first;
    264     parens[i].second = args->arg2[i].second;
    265   }
    266   PdtInfo<Arc> pdtinfo(fst, parens);
    267   PrintPdtInfo(pdtinfo);
    268 }
    269 
    270 void PrintPdtInfo(const FstClass &ifst,
    271                   const vector<pair<int64, int64> > &parens);
    272 
    273 }  // namespace script
    274 }  // namespace fst
    275 
    276 
    277 #define REGISTER_FST_PDT_OPERATIONS(ArcType)                                \
    278   REGISTER_FST_OPERATION(PdtCompose, ArcType, PdtComposeArgs);              \
    279   REGISTER_FST_OPERATION(PdtExpand, ArcType, PdtExpandArgs);                \
    280   REGISTER_FST_OPERATION(PdtReplace, ArcType, PdtReplaceArgs);              \
    281   REGISTER_FST_OPERATION(PdtReverse, ArcType, PdtReverseArgs);              \
    282   REGISTER_FST_OPERATION(PdtShortestPath, ArcType, PdtShortestPathArgs);    \
    283   REGISTER_FST_OPERATION(PrintPdtInfo, ArcType, PrintPdtInfoArgs)
    284 #endif  // FST_EXTENSIONS_PDT_PDTSCRIPT_H_
    285