Home | History | Annotate | Download | only in script
      1 
      2 // Licensed under the Apache License, Version 2.0 (the "License");
      3 // you may not use this file except in compliance with the License.
      4 // You may obtain a copy of the License at
      5 //
      6 //     http://www.apache.org/licenses/LICENSE-2.0
      7 //
      8 // Unless required by applicable law or agreed to in writing, software
      9 // distributed under the License is distributed on an "AS IS" BASIS,
     10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     11 // See the License for the specific language governing permissions and
     12 // limitations under the License.
     13 //
     14 // Copyright 2005-2010 Google, Inc.
     15 // Author: jpr (at) google.com (Jake Ratkiewicz)
     16 
     17 #ifndef FST_SCRIPT_SHORTEST_DISTANCE_H_
     18 #define FST_SCRIPT_SHORTEST_DISTANCE_H_
     19 
     20 #include <vector>
     21 using std::vector;
     22 
     23 #include <fst/script/arg-packs.h>
     24 #include <fst/script/fst-class.h>
     25 #include <fst/script/weight-class.h>
     26 #include <fst/script/prune.h>  // for ArcFilterType
     27 #include <fst/queue.h>  // for QueueType
     28 #include <fst/shortest-distance.h>
     29 
     30 namespace fst {
     31 namespace script {
     32 
     33 enum ArcFilterType { ANY_ARC_FILTER, EPSILON_ARC_FILTER,
     34                      INPUT_EPSILON_ARC_FILTER, OUTPUT_EPSILON_ARC_FILTER };
     35 
     36 // See nlp/fst/lib/shortest-distance.h for the template options class
     37 // that this one shadows
     38 struct ShortestDistanceOptions {
     39   const QueueType queue_type;
     40   const ArcFilterType arc_filter_type;
     41   const int64 source;
     42   const float delta;
     43   const bool first_path;
     44 
     45   ShortestDistanceOptions(QueueType qt, ArcFilterType aft, int64 s,
     46                           float d)
     47       : queue_type(qt), arc_filter_type(aft), source(s), delta(d),
     48         first_path(false) { }
     49 };
     50 
     51 
     52 
     53 // 1
     54 typedef args::Package<const FstClass &, vector<WeightClass> *,
     55                       const ShortestDistanceOptions &> ShortestDistanceArgs1;
     56 
     57 template<class Queue, class Arc, class ArcFilter>
     58 struct QueueConstructor {
     59   //  template<class Arc, class ArcFilter>
     60   static Queue *Construct(const Fst<Arc> &,
     61                           const vector<typename Arc::Weight> *) {
     62     return new Queue();
     63   }
     64 };
     65 
     66 // Specializations to deal with AutoQueue, NaturalShortestFirstQueue,
     67 // and TopOrderQueue's different constructors
     68 template<class Arc, class ArcFilter>
     69 struct QueueConstructor<AutoQueue<typename Arc::StateId>, Arc, ArcFilter> {
     70   //  template<class Arc, class ArcFilter>
     71   static AutoQueue<typename Arc::StateId> *Construct(
     72       const Fst<Arc> &fst,
     73       const vector<typename Arc::Weight> *distance) {
     74     return new AutoQueue<typename Arc::StateId>(fst, distance, ArcFilter());
     75   }
     76 };
     77 
     78 template<class Arc, class ArcFilter>
     79 struct QueueConstructor<NaturalShortestFirstQueue<typename Arc::StateId,
     80                                                   typename Arc::Weight>,
     81                         Arc, ArcFilter> {
     82   //  template<class Arc, class ArcFilter>
     83   static NaturalShortestFirstQueue<typename Arc::StateId, typename Arc::Weight>
     84   *Construct(const Fst<Arc> &fst,
     85             const vector<typename Arc::Weight> *distance) {
     86     return new NaturalShortestFirstQueue<typename Arc::StateId,
     87                                          typename Arc::Weight>(*distance);
     88   }
     89 };
     90 
     91 template<class Arc, class ArcFilter>
     92 struct QueueConstructor<TopOrderQueue<typename Arc::StateId>, Arc, ArcFilter> {
     93   //  template<class Arc, class ArcFilter>
     94   static TopOrderQueue<typename Arc::StateId> *Construct(
     95       const Fst<Arc> &fst, const vector<typename Arc::Weight> *weights) {
     96     return new TopOrderQueue<typename Arc::StateId>(fst, ArcFilter());
     97   }
     98 };
     99 
    100 
    101 template<class Arc, class Queue>
    102 void ShortestDistanceHelper(ShortestDistanceArgs1 *args) {
    103   const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
    104   const ShortestDistanceOptions &opts = args->arg3;
    105 
    106   vector<typename Arc::Weight> weights;
    107 
    108   switch (opts.arc_filter_type) {
    109     case ANY_ARC_FILTER: {
    110       Queue *queue =
    111           QueueConstructor<Queue, Arc, AnyArcFilter<Arc> >::Construct(
    112               fst, &weights);
    113       fst::ShortestDistanceOptions<Arc, Queue, AnyArcFilter<Arc> > sdopts(
    114           queue, AnyArcFilter<Arc>(), opts.source, opts.delta);
    115       ShortestDistance(fst, &weights, sdopts);
    116       delete queue;
    117       break;
    118     }
    119     case EPSILON_ARC_FILTER: {
    120       Queue *queue =
    121           QueueConstructor<Queue, Arc, AnyArcFilter<Arc> >::Construct(
    122               fst, &weights);
    123       fst::ShortestDistanceOptions<Arc, Queue,
    124           EpsilonArcFilter<Arc> > sdopts(
    125               queue, EpsilonArcFilter<Arc>(), opts.source, opts.delta);
    126       ShortestDistance(fst, &weights, sdopts);
    127       delete queue;
    128       break;
    129     }
    130     case INPUT_EPSILON_ARC_FILTER: {
    131       Queue *queue =
    132           QueueConstructor<Queue, Arc, InputEpsilonArcFilter<Arc> >::Construct(
    133               fst, &weights);
    134       fst::ShortestDistanceOptions<Arc, Queue,
    135           InputEpsilonArcFilter<Arc> > sdopts(
    136               queue, InputEpsilonArcFilter<Arc>(), opts.source, opts.delta);
    137       ShortestDistance(fst, &weights, sdopts);
    138       delete queue;
    139       break;
    140     }
    141     case OUTPUT_EPSILON_ARC_FILTER: {
    142       Queue *queue =
    143           QueueConstructor<Queue, Arc,
    144           OutputEpsilonArcFilter<Arc> >::Construct(
    145               fst, &weights);
    146       fst::ShortestDistanceOptions<Arc, Queue,
    147           OutputEpsilonArcFilter<Arc> > sdopts(
    148               queue, OutputEpsilonArcFilter<Arc>(), opts.source, opts.delta);
    149       ShortestDistance(fst, &weights, sdopts);
    150       delete queue;
    151       break;
    152     }
    153   }
    154 
    155   // Copy the weights back
    156   args->arg2->resize(weights.size());
    157   for (unsigned i = 0; i < weights.size(); ++i) {
    158     (*args->arg2)[i] = WeightClass(weights[i]);
    159   }
    160 }
    161 
    162 template<class Arc>
    163 void ShortestDistance(ShortestDistanceArgs1 *args) {
    164   const ShortestDistanceOptions &opts = args->arg3;
    165   typedef typename Arc::StateId StateId;
    166   typedef typename Arc::Weight Weight;
    167 
    168   // Must consider (opts.queue_type x opts.filter_type) options
    169   switch (opts.queue_type) {
    170     default:
    171       FSTERROR() << "Unknown queue type." << opts.queue_type;
    172 
    173     case AUTO_QUEUE:
    174       ShortestDistanceHelper<Arc, AutoQueue<StateId> >(args);
    175       return;
    176 
    177     case FIFO_QUEUE:
    178       ShortestDistanceHelper<Arc, FifoQueue<StateId> >(args);
    179       return;
    180 
    181     case LIFO_QUEUE:
    182       ShortestDistanceHelper<Arc, LifoQueue<StateId> >(args);
    183       return;
    184 
    185     case SHORTEST_FIRST_QUEUE:
    186       ShortestDistanceHelper<Arc,
    187         NaturalShortestFirstQueue<StateId, Weight> >(args);
    188       return;
    189 
    190     case STATE_ORDER_QUEUE:
    191       ShortestDistanceHelper<Arc, StateOrderQueue<StateId> >(args);
    192       return;
    193 
    194     case TOP_ORDER_QUEUE:
    195       ShortestDistanceHelper<Arc, TopOrderQueue<StateId> >(args);
    196       return;
    197   }
    198 }
    199 
    200 // 2
    201 typedef args::Package<const FstClass&, vector<WeightClass>*,
    202                       bool, double> ShortestDistanceArgs2;
    203 
    204 template<class Arc>
    205 void ShortestDistance(ShortestDistanceArgs2 *args) {
    206   const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
    207   vector<typename Arc::Weight> distance;
    208 
    209   ShortestDistance(fst, &distance, args->arg3, args->arg4);
    210 
    211   // convert the typed weights back into weightclass
    212   vector<WeightClass> *retval = args->arg2;
    213   retval->resize(distance.size());
    214 
    215   for (unsigned i = 0; i < distance.size(); ++i) {
    216     (*retval)[i] = WeightClass(distance[i]);
    217   }
    218 }
    219 
    220 // 3
    221 typedef args::WithReturnValue<WeightClass,
    222                               const FstClass &> ShortestDistanceArgs3;
    223 
    224 template<class Arc>
    225 void ShortestDistance(ShortestDistanceArgs3 *args) {
    226   const Fst<Arc> &fst = *(args->args.GetFst<Arc>());
    227 
    228   args->retval = WeightClass(ShortestDistance(fst));
    229 }
    230 
    231 
    232 // 1
    233 void ShortestDistance(const FstClass &fst, vector<WeightClass> *distance,
    234                       const ShortestDistanceOptions &opts);
    235 
    236 // 2
    237 void ShortestDistance(const FstClass &ifst, vector<WeightClass> *distance,
    238                       bool reverse = false, double delta = fst::kDelta);
    239 
    240 #ifndef SWIG
    241 // 3
    242 WeightClass ShortestDistance(const FstClass &ifst);
    243 #endif
    244 
    245 }  // namespace script
    246 }  // namespace fst
    247 
    248 
    249 
    250 #endif  // FST_SCRIPT_SHORTEST_DISTANCE_H_
    251