Home | History | Annotate | Download | only in script
      1 
      2 // Licensed under the Apache License, Version 2.0 (the "License");
      3 // you may not use this file except in compliance with the License.
      4 // You may obtain a copy of the License at
      5 //
      6 //     http://www.apache.org/licenses/LICENSE-2.0
      7 //
      8 // Unless required by applicable law or agreed to in writing, software
      9 // distributed under the License is distributed on an "AS IS" BASIS,
     10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     11 // See the License for the specific language governing permissions and
     12 // limitations under the License.
     13 //
     14 // Copyright 2005-2010 Google, Inc.
     15 // Author: jpr (at) google.com (Jake Ratkiewicz)
     16 
     17 #ifndef FST_SCRIPT_SHORTEST_PATH_H_
     18 #define FST_SCRIPT_SHORTEST_PATH_H_
     19 
     20 #include <vector>
     21 using std::vector;
     22 
     23 #include <fst/script/arg-packs.h>
     24 #include <fst/script/fst-class.h>
     25 #include <fst/script/weight-class.h>
     26 #include <fst/shortest-path.h>
     27 #include <fst/script/shortest-distance.h>  // for ShortestDistanceOptions
     28 
     29 namespace fst {
     30 namespace script {
     31 
     32 struct ShortestPathOptions
     33     : public fst::script::ShortestDistanceOptions {
     34   const size_t nshortest;
     35   const bool unique;
     36   const bool has_distance;
     37   const bool first_path;
     38   const WeightClass weight_threshold;
     39   const int64 state_threshold;
     40 
     41   ShortestPathOptions(QueueType qt, size_t n = 1,
     42                       bool u = false, bool hasdist = false,
     43                       float d = fst::kDelta, bool fp = false,
     44                       WeightClass w = fst::script::WeightClass::Zero(),
     45                       int64 s = fst::kNoStateId)
     46       : ShortestDistanceOptions(qt, ANY_ARC_FILTER, kNoStateId, d),
     47         nshortest(n), unique(u), has_distance(hasdist), first_path(fp),
     48         weight_threshold(w), state_threshold(s) { }
     49 };
     50 
     51 typedef args::Package<const FstClass &, MutableFstClass *,
     52                       vector<WeightClass> *, const ShortestPathOptions &>
     53   ShortestPathArgs1;
     54 
     55 
     56 template<class Arc>
     57 void ShortestPath(ShortestPathArgs1 *args) {
     58   const Fst<Arc> &ifst = *(args->arg1.GetFst<Arc>());
     59   MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>();
     60   const ShortestPathOptions &opts = args->arg4;
     61   typedef typename Arc::StateId StateId;
     62   typedef typename Arc::Weight Weight;
     63   typedef AnyArcFilter<Arc> ArcFilter;
     64 
     65   vector<typename Arc::Weight> weights;
     66   typename Arc::Weight weight_threshold =
     67       *(opts.weight_threshold.GetWeight<Weight>());
     68 
     69   switch (opts.queue_type) {
     70     case AUTO_QUEUE: {
     71       typedef AutoQueue<StateId> Queue;
     72       Queue *queue = QueueConstructor<Queue, Arc,
     73           ArcFilter>::Construct(ifst, &weights);
     74       fst::ShortestPathOptions<Arc, Queue, ArcFilter> spopts(
     75           queue, ArcFilter(), opts.nshortest, opts.unique,
     76           opts.has_distance, opts.delta, opts.first_path,
     77           weight_threshold, opts.state_threshold);
     78       ShortestPath(ifst, ofst, &weights, spopts);
     79       delete queue;
     80       return;
     81     }
     82     case FIFO_QUEUE: {
     83       typedef FifoQueue<StateId> Queue;
     84       Queue *queue = QueueConstructor<Queue, Arc,
     85           ArcFilter>::Construct(ifst, &weights);
     86       fst::ShortestPathOptions<Arc, Queue, ArcFilter> spopts(
     87           queue, ArcFilter(), opts.nshortest, opts.unique,
     88           opts.has_distance, opts.delta, opts.first_path,
     89           weight_threshold, opts.state_threshold);
     90       ShortestPath(ifst, ofst, &weights, spopts);
     91       delete queue;
     92       return;
     93     }
     94     case LIFO_QUEUE: {
     95       typedef LifoQueue<StateId> Queue;
     96       Queue *queue = QueueConstructor<Queue, Arc,
     97           ArcFilter >::Construct(ifst, &weights);
     98       fst::ShortestPathOptions<Arc, Queue, ArcFilter> spopts(
     99           queue, ArcFilter(), opts.nshortest, opts.unique,
    100           opts.has_distance, opts.delta, opts.first_path,
    101           weight_threshold, opts.state_threshold);
    102       ShortestPath(ifst, ofst, &weights, spopts);
    103       delete queue;
    104       return;
    105     }
    106     case SHORTEST_FIRST_QUEUE: {
    107       typedef NaturalShortestFirstQueue<StateId, Weight> Queue;
    108       Queue *queue = QueueConstructor<Queue, Arc,
    109           ArcFilter>::Construct(ifst, &weights);
    110       fst::ShortestPathOptions<Arc, Queue, ArcFilter> spopts(
    111           queue, ArcFilter(), opts.nshortest, opts.unique,
    112           opts.has_distance, opts.delta, opts.first_path,
    113           weight_threshold, opts.state_threshold);
    114       ShortestPath(ifst, ofst, &weights, spopts);
    115       delete queue;
    116       return;
    117     }
    118     case STATE_ORDER_QUEUE: {
    119       typedef StateOrderQueue<StateId> Queue;
    120       Queue *queue = QueueConstructor<Queue, Arc,
    121           ArcFilter>::Construct(ifst, &weights);
    122       fst::ShortestPathOptions<Arc, Queue, ArcFilter> spopts(
    123           queue, ArcFilter(), opts.nshortest, opts.unique,
    124           opts.has_distance, opts.delta, opts.first_path,
    125           weight_threshold, opts.state_threshold);
    126       ShortestPath(ifst, ofst, &weights, spopts);
    127       delete queue;
    128       return;
    129     }
    130     case TOP_ORDER_QUEUE: {
    131       typedef TopOrderQueue<StateId> Queue;
    132       Queue *queue = QueueConstructor<Queue, Arc,
    133           ArcFilter>::Construct(ifst, &weights);
    134       fst::ShortestPathOptions<Arc, Queue, ArcFilter> spopts(
    135           queue, ArcFilter(), opts.nshortest, opts.unique,
    136           opts.has_distance, opts.delta, opts.first_path,
    137           weight_threshold, opts.state_threshold);
    138       ShortestPath(ifst, ofst, &weights, spopts);
    139       delete queue;
    140       return;
    141     }
    142     default:
    143       FSTERROR() << "Unknown queue type: " << opts.queue_type;
    144       ofst->SetProperties(kError, kError);
    145   }
    146 
    147   // Copy the weights back
    148   args->arg3->resize(weights.size());
    149   for (unsigned i = 0; i < weights.size(); ++i) {
    150     (*args->arg3)[i] = WeightClass(weights[i]);
    151   }
    152 }
    153 
    154 // 2
    155 typedef args::Package<const FstClass &, MutableFstClass *,
    156                       size_t, bool, bool, WeightClass,
    157                       int64> ShortestPathArgs2;
    158 
    159 template<class Arc>
    160 void ShortestPath(ShortestPathArgs2 *args) {
    161   const Fst<Arc> &ifst = *(args->arg1.GetFst<Arc>());
    162   MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>();
    163   typename Arc::Weight weight_threshold =
    164       *(args->arg6.GetWeight<typename Arc::Weight>());
    165 
    166   ShortestPath(ifst, ofst, args->arg3, args->arg4, args->arg5,
    167                weight_threshold, args->arg7);
    168 }
    169 
    170 
    171 // 1
    172 void ShortestPath(const FstClass &ifst, MutableFstClass *ofst,
    173                   vector<WeightClass> *distance,
    174                   const ShortestPathOptions &opts);
    175 
    176 
    177 // 2
    178 void ShortestPath(const FstClass &ifst, MutableFstClass *ofst,
    179                   size_t n = 1, bool unique = false,
    180                   bool first_path = false,
    181                   WeightClass weight_threshold =
    182                     fst::script::WeightClass::Zero(),
    183                   int64 state_threshold = fst::kNoStateId);
    184 
    185 }  // namespace script
    186 }  // namespace fst
    187 
    188 
    189 
    190 #endif  // FST_SCRIPT_SHORTEST_PATH_H_
    191