Home | History | Annotate | Download | only in fst
      1 // sparse-tuple-weight.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: krr (at) google.com (Kasturi Rangan Raghavan)
     17 // Inspiration: allauzen (at) google.com (Cyril Allauzen)
     18 // \file
// Sparse version of tuple-weight, based on tuple-weight.h.
//   Internally stores sparse (key, value) pairs in a linked list.
//   The default value element is the assumed value of unset keys.
//   As an internal optimization, the first (key, value) pair is stored in
//   an initialized member variable to avoid unnecessary heap allocation.
// Use SparseTupleWeightIterator to iterate through the (key, value) pairs.
// Note: this does NOT iterate through the default value.
     27 //
     28 // Sparse tuple weight set operation definitions.
     29 
     30 #ifndef FST_LIB_SPARSE_TUPLE_WEIGHT_H__
     31 #define FST_LIB_SPARSE_TUPLE_WEIGHT_H__
     32 
     33 #include<string>
     34 #include<list>
     35 #include<stack>
     36 #include<unordered_map>
     37 using std::tr1::unordered_map;
     38 using std::tr1::unordered_multimap;
     39 
     40 #include <fst/weight.h>
     41 
     42 
     43 DECLARE_string(fst_weight_parentheses);
     44 DECLARE_string(fst_weight_separator);
     45 
     46 namespace fst {
     47 
// Forward declarations: SparseTupleWeight and SparseTupleWeightIterator
// refer to each other, and SparseTupleWeight befriends the operator>>
// template declared here.
template <class W, class K> class SparseTupleWeight;

template<class W, class K>
class SparseTupleWeightIterator;

template <class W, class K>
istream &operator>>(istream &strm, SparseTupleWeight<W, K> &w);
     55 
     56 // Arbitrary dimension tuple weight, stored as a sorted linked-list
     57 // W is any weight class,
     58 // K is the key value type. kNoKey(-1) is reserved for internal use
     59 template <class W, class K = int>
     60 class SparseTupleWeight {
     61  public:
     62   typedef pair<K, W> Pair;
     63   typedef SparseTupleWeight<typename W::ReverseWeight, K> ReverseWeight;
     64 
     65   const static K kNoKey = -1;
     66   SparseTupleWeight() {
     67     Init();
     68   }
     69 
     70   template <class Iterator>
     71   SparseTupleWeight(Iterator begin, Iterator end) {
     72     Init();
     73     // Assumes input iterator is sorted
     74     for (Iterator it = begin; it != end; ++it)
     75       Push(*it);
     76   }
     77 
     78 
     79   SparseTupleWeight(const K& key, const W &w) {
     80     Init();
     81     Push(key, w);
     82   }
     83 
     84   SparseTupleWeight(const W &w) {
     85     Init(w);
     86   }
     87 
     88   SparseTupleWeight(const SparseTupleWeight<W, K> &w) {
     89     Init(w.DefaultValue());
     90     SetDefaultValue(w.DefaultValue());
     91     for (SparseTupleWeightIterator<W, K> it(w); !it.Done(); it.Next()) {
     92       Push(it.Value());
     93     }
     94   }
     95 
     96   static const SparseTupleWeight<W, K> &Zero() {
     97     static SparseTupleWeight<W, K> zero;
     98     return zero;
     99   }
    100 
    101   static const SparseTupleWeight<W, K> &One() {
    102     static SparseTupleWeight<W, K> one(W::One());
    103     return one;
    104   }
    105 
    106   static const SparseTupleWeight<W, K> &NoWeight() {
    107     static SparseTupleWeight<W, K> no_weight(W::NoWeight());
    108     return no_weight;
    109   }
    110 
    111   istream &Read(istream &strm) {
    112     ReadType(strm, &default_);
    113     ReadType(strm, &first_);
    114     return ReadType(strm, &rest_);
    115   }
    116 
    117   ostream &Write(ostream &strm) const {
    118     WriteType(strm, default_);
    119     WriteType(strm, first_);
    120     return WriteType(strm, rest_);
    121   }
    122 
    123   SparseTupleWeight<W, K> &operator=(const SparseTupleWeight<W, K> &w) {
    124     if (this == &w) return *this; // check for w = w
    125     Init(w.DefaultValue());
    126     for (SparseTupleWeightIterator<W, K> it(w); !it.Done(); it.Next()) {
    127       Push(it.Value());
    128     }
    129     return *this;
    130   }
    131 
    132   bool Member() const {
    133     if (!DefaultValue().Member()) return false;
    134     for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) {
    135       if (!it.Value().second.Member()) return false;
    136     }
    137     return true;
    138   }
    139 
    140   // Assumes H() function exists for the hash of the key value
    141   size_t Hash() const {
    142     uint64 h = 0;
    143     std::hash<K> H;
    144     for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) {
    145       h = 5 * h + H(it.Value().first);
    146       h = 13 * h + it.Value().second.Hash();
    147     }
    148     return size_t(h);
    149   }
    150 
    151   SparseTupleWeight<W, K> Quantize(float delta = kDelta) const {
    152     SparseTupleWeight<W, K> w;
    153     for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) {
    154       w.Push(it.Value().first, it.Value().second.Quantize(delta));
    155     }
    156     return w;
    157   }
    158 
    159   ReverseWeight Reverse() const {
    160     SparseTupleWeight<W, K> w;
    161     for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) {
    162       w.Push(it.Value().first, it.Value().second.Reverse());
    163     }
    164     return w;
    165   }
    166 
    167   // Common initializer among constructors.
    168   void Init() {
    169     Init(W::Zero());
    170   }
    171 
    172   void Init(const W& default_value) {
    173     first_.first = kNoKey;
    174     /* initialized to the reserved key value */
    175     default_ = default_value;
    176     rest_.clear();
    177   }
    178 
    179   size_t Size() const {
    180     if (first_.first == kNoKey)
    181       return 0;
    182     else
    183       return  rest_.size() + 1;
    184   }
    185 
    186   inline void Push(const K &k, const W &w, bool default_value_check = true) {
    187     Push(make_pair(k, w), default_value_check);
    188   }
    189 
    190   inline void Push(const Pair &p, bool default_value_check = true) {
    191     if (default_value_check && p.second == default_) return;
    192     if (first_.first == kNoKey) {
    193       first_ = p;
    194     } else {
    195       rest_.push_back(p);
    196     }
    197   }
    198 
    199   void SetDefaultValue(const W& val) { default_ = val; }
    200 
    201   const W& DefaultValue() const { return default_; }
    202 
    203  protected:
    204   static istream& ReadNoParen(
    205     istream&, SparseTupleWeight<W, K>&, char separator);
    206 
    207   static istream& ReadWithParen(
    208     istream&, SparseTupleWeight<W, K>&,
    209     char separator, char open_paren, char close_paren);
    210 
    211  private:
    212   // Assumed default value of uninitialized keys, by default W::Zero()
    213   W default_;
    214 
    215   // Key values pairs are first stored in first_, then fill rest_
    216   // this way we can avoid dynamic allocation in the common case
    217   // where the weight is a single key,val pair.
    218   Pair first_;
    219   list<Pair> rest_;
    220 
    221   friend istream &operator>><W, K>(istream&, SparseTupleWeight<W, K>&);
    222   friend class SparseTupleWeightIterator<W, K>;
    223 };
    224 
    225 template<class W, class K>
    226 class SparseTupleWeightIterator {
    227  public:
    228   typedef typename SparseTupleWeight<W, K>::Pair Pair;
    229   typedef typename list<Pair>::const_iterator const_iterator;
    230   typedef typename list<Pair>::iterator iterator;
    231 
    232   explicit SparseTupleWeightIterator(const SparseTupleWeight<W, K>& w)
    233     : first_(w.first_), rest_(w.rest_), init_(true),
    234       iter_(rest_.begin()) {}
    235 
    236   bool Done() const {
    237     if (init_)
    238       return first_.first == SparseTupleWeight<W, K>::kNoKey;
    239     else
    240       return iter_ == rest_.end();
    241   }
    242 
    243   const Pair& Value() const { return init_ ? first_ : *iter_; }
    244 
    245   void Next() {
    246     if (init_)
    247       init_ = false;
    248     else
    249       ++iter_;
    250   }
    251 
    252   void Reset() {
    253     init_ = true;
    254     iter_ = rest_.begin();
    255   }
    256 
    257  private:
    258   const Pair &first_;
    259   const list<Pair> & rest_;
    260   bool init_;  // in the initialized state?
    261   typename list<Pair>::const_iterator iter_;
    262 
    263   DISALLOW_COPY_AND_ASSIGN(SparseTupleWeightIterator);
    264 };
    265 
    266 template<class W, class K, class M>
    267 inline void SparseTupleWeightMap(
    268   SparseTupleWeight<W, K>* ret,
    269   const SparseTupleWeight<W, K>& w1,
    270   const SparseTupleWeight<W, K>& w2,
    271   const M& operator_mapper) {
    272   SparseTupleWeightIterator<W, K> w1_it(w1);
    273   SparseTupleWeightIterator<W, K> w2_it(w2);
    274   const W& v1_def = w1.DefaultValue();
    275   const W& v2_def = w2.DefaultValue();
    276   ret->SetDefaultValue(operator_mapper.Map(0, v1_def, v2_def));
    277   while (!w1_it.Done() || !w2_it.Done()) {
    278     const K& k1 = (w1_it.Done()) ? w2_it.Value().first : w1_it.Value().first;
    279     const K& k2 = (w2_it.Done()) ? w1_it.Value().first : w2_it.Value().first;
    280     const W& v1 = (w1_it.Done()) ? v1_def : w1_it.Value().second;
    281     const W& v2 = (w2_it.Done()) ? v2_def : w2_it.Value().second;
    282     if (k1 == k2) {
    283       ret->Push(k1, operator_mapper.Map(k1, v1, v2));
    284       if (!w1_it.Done()) w1_it.Next();
    285       if (!w2_it.Done()) w2_it.Next();
    286     } else if (k1 < k2) {
    287       ret->Push(k1, operator_mapper.Map(k1, v1, v2_def));
    288       w1_it.Next();
    289     } else {
    290       ret->Push(k2, operator_mapper.Map(k2, v1_def, v2));
    291       w2_it.Next();
    292     }
    293   }
    294 }
    295 
    296 template <class W, class K>
    297 inline bool operator==(const SparseTupleWeight<W, K> &w1,
    298                        const SparseTupleWeight<W, K> &w2) {
    299   const W& v1_def = w1.DefaultValue();
    300   const W& v2_def = w2.DefaultValue();
    301   if (v1_def != v2_def) return false;
    302 
    303   SparseTupleWeightIterator<W, K> w1_it(w1);
    304   SparseTupleWeightIterator<W, K> w2_it(w2);
    305   while (!w1_it.Done() || !w2_it.Done()) {
    306     const K& k1 = (w1_it.Done()) ? w2_it.Value().first : w1_it.Value().first;
    307     const K& k2 = (w2_it.Done()) ? w1_it.Value().first : w2_it.Value().first;
    308     const W& v1 = (w1_it.Done()) ? v1_def : w1_it.Value().second;
    309     const W& v2 = (w2_it.Done()) ? v2_def : w2_it.Value().second;
    310     if (k1 == k2) {
    311       if (v1 != v2) return false;
    312       if (!w1_it.Done()) w1_it.Next();
    313       if (!w2_it.Done()) w2_it.Next();
    314     } else if (k1 < k2) {
    315       if (v1 != v2_def) return false;
    316       w1_it.Next();
    317     } else {
    318       if (v1_def != v2) return false;
    319       w2_it.Next();
    320     }
    321   }
    322   return true;
    323 }
    324 
    325 template <class W, class K>
    326 inline bool operator!=(const SparseTupleWeight<W, K> &w1,
    327                        const SparseTupleWeight<W, K> &w2) {
    328   return !(w1 == w2);
    329 }
    330 
    331 template <class W, class K>
    332 inline ostream &operator<<(ostream &strm, const SparseTupleWeight<W, K> &w) {
    333   if(FLAGS_fst_weight_separator.size() != 1) {
    334     FSTERROR() << "FLAGS_fst_weight_separator.size() is not equal to 1";
    335     strm.clear(std::ios::badbit);
    336     return strm;
    337   }
    338   char separator = FLAGS_fst_weight_separator[0];
    339   bool write_parens = false;
    340   if (!FLAGS_fst_weight_parentheses.empty()) {
    341     if (FLAGS_fst_weight_parentheses.size() != 2) {
    342       FSTERROR() << "FLAGS_fst_weight_parentheses.size() is not equal to 2";
    343       strm.clear(std::ios::badbit);
    344       return strm;
    345     }
    346     write_parens = true;
    347   }
    348 
    349   if (write_parens)
    350     strm << FLAGS_fst_weight_parentheses[0];
    351 
    352   strm << w.DefaultValue();
    353   strm << separator;
    354 
    355   size_t n = w.Size();
    356   strm << n;
    357   strm << separator;
    358 
    359   for (SparseTupleWeightIterator<W, K> it(w); !it.Done(); it.Next()) {
    360       strm << it.Value().first;
    361       strm << separator;
    362       strm << it.Value().second;
    363       strm << separator;
    364   }
    365 
    366   if (write_parens)
    367     strm << FLAGS_fst_weight_parentheses[1];
    368 
    369   return strm;
    370 }
    371 
    372 template <class W, class K>
    373 inline istream &operator>>(istream &strm, SparseTupleWeight<W, K> &w) {
    374   if(FLAGS_fst_weight_separator.size() != 1) {
    375     FSTERROR() << "FLAGS_fst_weight_separator.size() is not equal to 1";
    376     strm.clear(std::ios::badbit);
    377     return strm;
    378   }
    379   char separator = FLAGS_fst_weight_separator[0];
    380 
    381   if (!FLAGS_fst_weight_parentheses.empty()) {
    382     if (FLAGS_fst_weight_parentheses.size() != 2) {
    383       FSTERROR() << "FLAGS_fst_weight_parentheses.size() is not equal to 2";
    384       strm.clear(std::ios::badbit);
    385       return strm;
    386     }
    387     return SparseTupleWeight<W, K>::ReadWithParen(
    388         strm, w, separator, FLAGS_fst_weight_parentheses[0],
    389         FLAGS_fst_weight_parentheses[1]);
    390   } else {
    391     return SparseTupleWeight<W, K>::ReadNoParen(strm, w, separator);
    392   }
    393 }
    394 
    395 // Reads SparseTupleWeight when there are no parentheses around tuple terms
    396 template <class W, class K>
    397 inline istream& SparseTupleWeight<W, K>::ReadNoParen(
    398     istream &strm,
    399     SparseTupleWeight<W, K> &w,
    400     char separator) {
    401   int c;
    402   size_t n;
    403 
    404   do {
    405     c = strm.get();
    406   } while (isspace(c));
    407 
    408 
    409   { // Read default weight
    410     W default_value;
    411     string s;
    412     while (c != separator) {
    413       if (c == EOF) {
    414         strm.clear(std::ios::badbit);
    415         return strm;
    416       }
    417       s += c;
    418       c = strm.get();
    419     }
    420     istringstream sstrm(s);
    421     sstrm >> default_value;
    422     w.SetDefaultValue(default_value);
    423   }
    424 
    425   c = strm.get();
    426 
    427   { // Read n
    428     string s;
    429     while (c != separator) {
    430       if (c == EOF) {
    431         strm.clear(std::ios::badbit);
    432         return strm;
    433       }
    434       s += c;
    435       c = strm.get();
    436     }
    437     istringstream sstrm(s);
    438     sstrm >> n;
    439   }
    440 
    441   // Read n elements
    442   for (size_t i = 0; i < n; ++i) {
    443     // discard separator
    444     c = strm.get();
    445     K p;
    446     W r;
    447 
    448     { // read key
    449       string s;
    450       while (c != separator) {
    451         if (c == EOF) {
    452           strm.clear(std::ios::badbit);
    453           return strm;
    454         }
    455         s += c;
    456         c = strm.get();
    457       }
    458       istringstream sstrm(s);
    459       sstrm >> p;
    460     }
    461 
    462     c = strm.get();
    463 
    464     { // read weight
    465       string s;
    466       while (c != separator) {
    467         if (c == EOF) {
    468           strm.clear(std::ios::badbit);
    469           return strm;
    470         }
    471         s += c;
    472         c = strm.get();
    473       }
    474       istringstream sstrm(s);
    475       sstrm >> r;
    476     }
    477 
    478     w.Push(p, r);
    479   }
    480 
    481   c = strm.get();
    482   if (c != separator) {
    483     strm.clear(std::ios::badbit);
    484   }
    485 
    486   return strm;
    487 }
    488 
    489 // Reads SparseTupleWeight when there are parentheses around tuple terms
    490 template <class W, class K>
    491 inline istream& SparseTupleWeight<W, K>::ReadWithParen(
    492     istream &strm,
    493     SparseTupleWeight<W, K> &w,
    494     char separator,
    495     char open_paren,
    496     char close_paren) {
    497   int c;
    498   size_t n;
    499 
    500   do {
    501     c = strm.get();
    502   } while (isspace(c));
    503 
    504   if (c != open_paren) {
    505     FSTERROR() << "is fst_weight_parentheses flag set correcty? ";
    506     strm.clear(std::ios::badbit);
    507     return strm;
    508   }
    509 
    510   c = strm.get();
    511 
    512   { // Read weight
    513     W default_value;
    514     stack<int> parens;
    515     string s;
    516     while (c != separator || !parens.empty()) {
    517       if (c == EOF) {
    518         strm.clear(std::ios::badbit);
    519         return strm;
    520       }
    521       s += c;
    522       // If parens encountered before separator, they must be matched
    523       if (c == open_paren) {
    524         parens.push(1);
    525       } else if (c == close_paren) {
    526         // Fail for mismatched parens
    527         if (parens.empty()) {
    528           strm.clear(std::ios::failbit);
    529           return strm;
    530         }
    531         parens.pop();
    532       }
    533       c = strm.get();
    534     }
    535     istringstream sstrm(s);
    536     sstrm >> default_value;
    537     w.SetDefaultValue(default_value);
    538   }
    539 
    540   c = strm.get();
    541 
    542   { // Read n
    543     string s;
    544     while (c != separator) {
    545       if (c == EOF) {
    546         strm.clear(std::ios::badbit);
    547         return strm;
    548       }
    549       s += c;
    550       c = strm.get();
    551     }
    552     istringstream sstrm(s);
    553     sstrm >> n;
    554   }
    555 
    556   // Read n elements
    557   for (size_t i = 0; i < n; ++i) {
    558     // discard separator
    559     c = strm.get();
    560     K p;
    561     W r;
    562 
    563     { // Read key
    564       stack<int> parens;
    565       string s;
    566       while (c != separator || !parens.empty()) {
    567         if (c == EOF) {
    568           strm.clear(std::ios::badbit);
    569           return strm;
    570         }
    571         s += c;
    572         // If parens encountered before separator, they must be matched
    573         if (c == open_paren) {
    574           parens.push(1);
    575         } else if (c == close_paren) {
    576           // Fail for mismatched parens
    577           if (parens.empty()) {
    578             strm.clear(std::ios::failbit);
    579             return strm;
    580           }
    581           parens.pop();
    582         }
    583         c = strm.get();
    584       }
    585       istringstream sstrm(s);
    586       sstrm >> p;
    587     }
    588 
    589     c = strm.get();
    590 
    591     { // Read weight
    592       stack<int> parens;
    593       string s;
    594       while (c != separator || !parens.empty()) {
    595         if (c == EOF) {
    596           strm.clear(std::ios::badbit);
    597           return strm;
    598         }
    599         s += c;
    600         // If parens encountered before separator, they must be matched
    601         if (c == open_paren) {
    602           parens.push(1);
    603         } else if (c == close_paren) {
    604           // Fail for mismatched parens
    605           if (parens.empty()) {
    606             strm.clear(std::ios::failbit);
    607             return strm;
    608           }
    609           parens.pop();
    610         }
    611         c = strm.get();
    612       }
    613       istringstream sstrm(s);
    614       sstrm >> r;
    615     }
    616 
    617     w.Push(p, r);
    618   }
    619 
    620   if (c != separator) {
    621     FSTERROR() << " separator expected, not found! ";
    622     strm.clear(std::ios::badbit);
    623     return strm;
    624   }
    625 
    626   c = strm.get();
    627   if (c != close_paren) {
    628     FSTERROR() << " is fst_weight_parentheses flag set correcty? ";
    629     strm.clear(std::ios::badbit);
    630     return strm;
    631   }
    632 
    633   return strm;
    634 }
    635 
    636 
    637 
    638 }  // namespace fst
    639 
    640 #endif  // FST_LIB_SPARSE_TUPLE_WEIGHT_H__
    641