Home | History | Annotate | Download | only in pdt
      1 // replace.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // \file
     19 // Recursively replace Fst arcs with other Fst(s) returning a PDT.
     20 
     21 #ifndef FST_EXTENSIONS_PDT_REPLACE_H__
     22 #define FST_EXTENSIONS_PDT_REPLACE_H__
     23 
     24 #include <tr1/unordered_map>
     25 using std::tr1::unordered_map;
     26 using std::tr1::unordered_multimap;
     27 
     28 #include <fst/replace.h>
     29 
     30 namespace fst {
     31 
     32 // Hash to paren IDs
     33 template <typename S>
     34 struct ReplaceParenHash {
     35   size_t operator()(const pair<size_t, S> &p) const {
     36     return p.first + p.second * kPrime;
     37   }
     38  private:
     39   static const size_t kPrime = 7853;
     40 };
     41 
     42 template <typename S> const size_t ReplaceParenHash<S>::kPrime;
     43 
     44 // Builds a pushdown transducer (PDT) from an RTN specification
     45 // identical to that in fst/lib/replace.h. The result is a PDT
     46 // encoded as the FST 'ofst' where some transitions are labeled with
     47 // open or close parentheses. To be interpreted as a PDT, the parens
     48 // must balance on a path (see PdtExpand()). The open/close
     49 // parenthesis label pairs are returned in 'parens'.
     50 template <class Arc>
     51 void Replace(const vector<pair<typename Arc::Label,
     52              const Fst<Arc>* > >& ifst_array,
     53              MutableFst<Arc> *ofst,
     54              vector<pair<typename Arc::Label,
     55              typename Arc::Label> > *parens,
     56              typename Arc::Label root) {
     57   typedef typename Arc::Label Label;
     58   typedef typename Arc::StateId StateId;
     59   typedef typename Arc::Weight Weight;
     60 
     61   ofst->DeleteStates();
     62   parens->clear();
     63 
     64   unordered_map<Label, size_t> label2id;
     65   for (size_t i = 0; i < ifst_array.size(); ++i)
     66     label2id[ifst_array[i].first] = i;
     67 
     68   Label max_label = kNoLabel;
     69   size_t max_non_term_count = 0;
     70 
     71   // Queue of non-terminals to replace
     72   deque<size_t> non_term_queue;
     73   // Map of non-terminals to replace to count
     74   unordered_map<Label, size_t> non_term_map;
     75   non_term_queue.push_back(root);
     76   non_term_map[root] = 1;;
     77 
     78   // PDT state corr. to ith replace FST start state.
     79   vector<StateId> fst_start(ifst_array.size(), kNoLabel);
     80   // PDT state, weight pairs corr. to ith replace FST final state & weights.
     81   vector< vector<pair<StateId, Weight> > > fst_final(ifst_array.size());
     82 
     83   // Builds single Fst combining all referenced input Fsts. Leaves in the
     84   // non-termnals for now.  Tabulate the PDT states that correspond to
     85   // the start and final states of the input Fsts.
     86   for (StateId soff = 0; !non_term_queue.empty(); soff = ofst->NumStates()) {
     87     Label label = non_term_queue.front();
     88     non_term_queue.pop_front();
     89     size_t fst_id = label2id[label];
     90 
     91     const Fst<Arc> *ifst = ifst_array[fst_id].second;
     92     for (StateIterator< Fst<Arc> > siter(*ifst);
     93          !siter.Done(); siter.Next()) {
     94       StateId is = siter.Value();
     95       StateId os = ofst->AddState();
     96       if (is == ifst->Start()) {
     97         fst_start[fst_id] = os;
     98         if (label == root)
     99           ofst->SetStart(os);
    100       }
    101       if (ifst->Final(is) != Weight::Zero()) {
    102         if (label == root)
    103           ofst->SetFinal(os, ifst->Final(is));
    104         fst_final[fst_id].push_back(make_pair(os, ifst->Final(is)));
    105       }
    106       for (ArcIterator< Fst<Arc> > aiter(*ifst, is);
    107            !aiter.Done(); aiter.Next()) {
    108         Arc arc = aiter.Value();
    109         if (max_label == kNoLabel || arc.olabel > max_label)
    110           max_label = arc.olabel;
    111         typename unordered_map<Label, size_t>::const_iterator it =
    112             label2id.find(arc.olabel);
    113         if (it != label2id.end()) {
    114           size_t nfst_id = it->second;
    115           if (ifst_array[nfst_id].second->Start() == -1)
    116             continue;
    117           size_t count = non_term_map[arc.olabel]++;
    118           if (count == 0)
    119             non_term_queue.push_back(arc.olabel);
    120           if (count > max_non_term_count)
    121             max_non_term_count = count;
    122         }
    123         arc.nextstate += soff;
    124         ofst->AddArc(os, arc);
    125       }
    126     }
    127   }
    128 
    129   // Changes each non-terminal transition to an open parenthesis
    130   // transition redirected to the PDT state that corresponds to the
    131   // start state of the input FST for the non-terminal. Adds close parenthesis
    132   // transitions from the PDT states corr. to the final states of the
    133   // input FST for the non-terminal to the former destination state of the
    134   // non-terminal transition.
    135 
    136   typedef MutableArcIterator< MutableFst<Arc> > MIter;
    137   typedef unordered_map<pair<size_t, StateId >, size_t,
    138                    ReplaceParenHash<StateId> > ParenMap;
    139 
    140   // Parenthesis pair ID per fst, state pair.
    141   ParenMap paren_map;
    142   // # of parenthesis pairs per fst.
    143   vector<size_t> nparens(ifst_array.size(), 0);
    144   // Initial open parenthesis label
    145   Label first_open_paren = max_label + 1;
    146   Label first_close_paren = max_label + max_non_term_count + 1;
    147 
    148   for (StateIterator< Fst<Arc> > siter(*ofst);
    149        !siter.Done(); siter.Next()) {
    150     StateId os = siter.Value();
    151     MIter *aiter = new MIter(ofst, os);
    152     for (size_t n = 0; !aiter->Done(); aiter->Next(), ++n) {
    153       Arc arc = aiter->Value();
    154       typename unordered_map<Label, size_t>::const_iterator lit =
    155           label2id.find(arc.olabel);
    156       if (lit != label2id.end()) {
    157         size_t nfst_id = lit->second;
    158 
    159         // Get parentheses. Ensures distinct parenthesis pair per
    160         // non-terminal and destination state but otherwise reuses them.
    161         Label open_paren = kNoLabel, close_paren = kNoLabel;
    162         pair<size_t, StateId> paren_key(nfst_id, arc.nextstate);
    163         typename ParenMap::const_iterator pit = paren_map.find(paren_key);
    164         if (pit != paren_map.end()) {
    165           size_t paren_id = pit->second;
    166           open_paren = (*parens)[paren_id].first;
    167           close_paren = (*parens)[paren_id].second;
    168         } else {
    169           size_t paren_id = nparens[nfst_id]++;
    170           open_paren = first_open_paren + paren_id;
    171           close_paren = first_close_paren + paren_id;
    172           paren_map[paren_key] = paren_id;
    173           if (paren_id >= parens->size())
    174             parens->push_back(make_pair(open_paren, close_paren));
    175         }
    176 
    177         // Sets open parenthesis.
    178         Arc sarc(open_paren, open_paren, arc.weight, fst_start[nfst_id]);
    179         aiter->SetValue(sarc);
    180 
    181         // Adds close parentheses.
    182         for (size_t i = 0; i < fst_final[nfst_id].size(); ++i) {
    183           pair<StateId, Weight> &p = fst_final[nfst_id][i];
    184           Arc farc(close_paren, close_paren, p.second, arc.nextstate);
    185 
    186           ofst->AddArc(p.first, farc);
    187           if (os == p.first) {  // Invalidated iterator
    188             delete aiter;
    189             aiter = new MIter(ofst, os);
    190             aiter->Seek(n);
    191           }
    192         }
    193       }
    194     }
    195     delete aiter;
    196   }
    197 }
    198 
    199 }  // namespace fst
    200 
    201 #endif  // FST_EXTENSIONS_PDT_REPLACE_H__
    202