Home | History | Annotate | Download | only in pdt
      1 // pdt.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // \file
     19 // Common classes for PDT expansion/traversal.
     20 
     21 #ifndef FST_EXTENSIONS_PDT_PDT_H__
     22 #define FST_EXTENSIONS_PDT_PDT_H__
     23 
     24 #include <unordered_map>
     25 using std::tr1::unordered_map;
     26 using std::tr1::unordered_multimap;
     27 #include <map>
     28 #include <set>
     29 
     30 #include <fst/state-table.h>
     31 #include <fst/fst.h>
     32 
     33 namespace fst {
     34 
     35 // Provides bijection between parenthesis stacks and signed integral
     36 // stack IDs. Each stack ID is unique to each distinct stack.  The
     37 // open-close parenthesis label pairs are passed in 'parens'.
     38 template <typename K, typename L>
     39 class PdtStack {
     40  public:
     41   typedef K StackId;
     42   typedef L Label;
     43 
     44   // The stacks are stored in a tree. The nodes are stored in vector
     45   // 'nodes_'. Each node represents the top of some stack and is
     46   // ID'ed by its position in the vector. Its parent node represents
     47   // the stack with the top 'popped' and its children are stored in
     48   // 'child_map_' accessed by stack_id and label. The paren_id is
     49   // the position in 'parens' of the parenthesis for that node.
     50   struct StackNode {
     51     StackId parent_id;
     52     size_t paren_id;
     53 
     54     StackNode(StackId p, size_t i) : parent_id(p), paren_id(i) {}
     55   };
     56 
     57   PdtStack(const vector<pair<Label, Label> > &parens)
     58       : parens_(parens), min_paren_(kNoLabel), max_paren_(kNoLabel) {
     59     for (size_t i = 0; i < parens.size(); ++i) {
     60       const pair<Label, Label>  &p = parens[i];
     61       paren_map_[p.first] = i;
     62       paren_map_[p.second] = i;
     63 
     64       if (min_paren_ == kNoLabel || p.first < min_paren_)
     65         min_paren_ = p.first;
     66       if (p.second < min_paren_)
     67         min_paren_ = p.second;
     68 
     69       if (max_paren_ == kNoLabel || p.first > max_paren_)
     70         max_paren_ = p.first;
     71       if (p.second > max_paren_)
     72         max_paren_ = p.second;
     73     }
     74     nodes_.push_back(StackNode(-1, -1));  // Tree root.
     75   }
     76 
     77   // Returns stack ID given the current stack ID (0 if empty) and
     78   // label read. 'Pushes' onto a stack if the label is an open
     79   // parenthesis, returning the new stack ID. 'Pops' the stack if the
     80   // label is a close parenthesis that matches the top of the stack,
     81   // returning the parent stack ID. Returns -1 if label is an
     82   // unmatched close parenthesis. Otherwise, returns the current stack
     83   // ID.
     84   StackId Find(StackId stack_id, Label label) {
     85     if (min_paren_ == kNoLabel || label < min_paren_ || label > max_paren_)
     86       return stack_id;                       // Non-paren.
     87 
     88     typename unordered_map<Label, size_t>::const_iterator pit
     89         = paren_map_.find(label);
     90     if (pit == paren_map_.end())             // Non-paren.
     91       return stack_id;
     92     ssize_t paren_id = pit->second;
     93 
     94     if (label == parens_[paren_id].first) {  // Open paren.
     95       StackId &child_id = child_map_[make_pair(stack_id, label)];
     96       if (child_id == 0) {                   // Child not found, push label.
     97         child_id = nodes_.size();
     98         nodes_.push_back(StackNode(stack_id, paren_id));
     99       }
    100       return child_id;
    101     }
    102 
    103     const StackNode &node = nodes_[stack_id];
    104     if (paren_id == node.paren_id)           // Matching close paren.
    105       return node.parent_id;
    106 
    107     return -1;                               // Non-matching close paren.
    108   }
    109 
    110   // Returns the stack ID obtained by "popping" the label at the top
    111   // of the current stack ID.
    112   StackId Pop(StackId stack_id) const {
    113     return nodes_[stack_id].parent_id;
    114   }
    115 
    116   // Returns the paren ID at the top of the stack for 'stack_id'
    117   ssize_t Top(StackId stack_id) const {
    118     return nodes_[stack_id].paren_id;
    119   }
    120 
    121   ssize_t ParenId(Label label) const {
    122     typename unordered_map<Label, size_t>::const_iterator pit
    123         = paren_map_.find(label);
    124     if (pit == paren_map_.end())  // Non-paren.
    125       return -1;
    126     return pit->second;
    127   }
    128 
    129  private:
    130   struct ChildHash {
    131     size_t operator()(const pair<StackId, Label> &p) const {
    132       return p.first + p.second * kPrime;
    133     }
    134   };
    135 
    136   static const size_t kPrime;
    137 
    138   vector<pair<Label, Label> > parens_;
    139   vector<StackNode> nodes_;
    140   unordered_map<Label, size_t> paren_map_;
    141   unordered_map<pair<StackId, Label>,
    142            StackId, ChildHash> child_map_;   // Child of stack node wrt label
    143   Label min_paren_;                          // For faster paren. check
    144   Label max_paren_;                          // For faster paren. check
    145 };
    146 
    147 template <typename T, typename L>
    148 const size_t PdtStack<T, L>::kPrime = 7853;
    149 
    150 
    151 // State tuple for PDT expansion
    152 template <typename S, typename K>
    153 struct PdtStateTuple {
    154   typedef S StateId;
    155   typedef K StackId;
    156 
    157   StateId state_id;
    158   StackId stack_id;
    159 
    160   PdtStateTuple()
    161       : state_id(kNoStateId), stack_id(-1) {}
    162 
    163   PdtStateTuple(StateId fs, StackId ss)
    164       : state_id(fs), stack_id(ss) {}
    165 };
    166 
    167 // Equality of PDT state tuples.
    168 template <typename S, typename K>
    169 inline bool operator==(const PdtStateTuple<S, K>& x,
    170                        const PdtStateTuple<S, K>& y) {
    171   if (&x == &y)
    172     return true;
    173   return x.state_id == y.state_id && x.stack_id == y.stack_id;
    174 }
    175 
    176 
    177 // Hash function object for PDT state tuples
    178 template <class T>
    179 class PdtStateHash {
    180  public:
    181   size_t operator()(const T &tuple) const {
    182     return tuple.state_id + tuple.stack_id * kPrime;
    183   }
    184 
    185  private:
    186   static const size_t kPrime;
    187 };
    188 
    189 template <typename T>
    190 const size_t PdtStateHash<T>::kPrime = 7853;
    191 
    192 
    193 // Tuple to PDT state bijection.
    194 template <class S, class K>
    195 class PdtStateTable
    196     : public CompactHashStateTable<PdtStateTuple<S, K>,
    197                                    PdtStateHash<PdtStateTuple<S, K> > > {
    198  public:
    199   typedef S StateId;
    200   typedef K StackId;
    201 
    202   PdtStateTable() {}
    203 
    204   PdtStateTable(const PdtStateTable<S, K> &table) {}
    205 
    206  private:
    207   void operator=(const PdtStateTable<S, K> &table);  // disallow
    208 };
    209 
    210 }  // namespace fst
    211 
    212 #endif  // FST_EXTENSIONS_PDT_PDT_H__
    213