Home | History | Annotate | Download | only in fst
      1 // encode.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: johans (at) google.com (Johan Schalkwyk)
     17 //
     18 // \file
     19 // Class to encode and decode an fst.
     20 
     21 #ifndef FST_LIB_ENCODE_H__
     22 #define FST_LIB_ENCODE_H__
     23 
     24 #include <climits>
     25 #include <unordered_map>
     26 using std::tr1::unordered_map;
     27 using std::tr1::unordered_multimap;
     28 #include <string>
     29 #include <vector>
     30 using std::vector;
     31 
     32 #include <fst/arc-map.h>
     33 #include <fst/rmfinalepsilon.h>
     34 
     35 
     36 namespace fst {
     37 
     38 static const uint32 kEncodeLabels      = 0x0001;
     39 static const uint32 kEncodeWeights     = 0x0002;
     40 static const uint32 kEncodeFlags       = 0x0003;  // All non-internal flags
     41 
     42 static const uint32 kEncodeHasISymbols = 0x0004;  // For internal use
     43 static const uint32 kEncodeHasOSymbols = 0x0008;  // For internal use
     44 
     45 enum EncodeType { ENCODE = 1, DECODE = 2 };
     46 
     47 // Identifies stream data as an encode table (and its endianity)
     48 static const int32 kEncodeMagicNumber = 2129983209;
     49 
     50 
     51 // The following class encapsulates implementation details for the
     52 // encoding and decoding of label/weight tuples used for encoding
     53 // and decoding of Fsts. The EncodeTable is bidirectional, i.e., it
     54 // stores both the mapping from a Tuple of labels and weight to a
     55 // unique encoded label, and the reverse mapping.
     56 template <class A>  class EncodeTable {
     57  public:
     58   typedef typename A::Label Label;
     59   typedef typename A::Weight Weight;
     60 
     61   // Encoded data consists of arc input/output labels and arc weight
     62   struct Tuple {
     63     Tuple() {}
     64     Tuple(Label ilabel_, Label olabel_, Weight weight_)
     65         : ilabel(ilabel_), olabel(olabel_), weight(weight_) {}
     66     Tuple(const Tuple& tuple)
     67         : ilabel(tuple.ilabel), olabel(tuple.olabel), weight(tuple.weight) {}
     68 
     69     Label ilabel;
     70     Label olabel;
     71     Weight weight;
     72   };
     73 
     74   // Comparison object for hashing EncodeTable Tuple(s).
     75   class TupleEqual {
     76    public:
     77     bool operator()(const Tuple* x, const Tuple* y) const {
     78       return (x->ilabel == y->ilabel &&
     79               x->olabel == y->olabel &&
     80               x->weight == y->weight);
     81     }
     82   };
     83 
     84   // Hash function for EncodeTabe Tuples. Based on the encode flags
     85   // we either hash the labels, weights or combination of them.
     86   class TupleKey {
     87    public:
     88     TupleKey()
     89         : encode_flags_(kEncodeLabels | kEncodeWeights) {}
     90 
     91     TupleKey(const TupleKey& key)
     92         : encode_flags_(key.encode_flags_) {}
     93 
     94     explicit TupleKey(uint32 encode_flags)
     95         : encode_flags_(encode_flags) {}
     96 
     97     size_t operator()(const Tuple* x) const {
     98       size_t hash = x->ilabel;
     99       const int lshift = 5;
    100       const int rshift = CHAR_BIT * sizeof(size_t) - 5;
    101       if (encode_flags_ & kEncodeLabels)
    102         hash = hash << lshift ^ hash >> rshift ^ x->olabel;
    103       if (encode_flags_ & kEncodeWeights)
    104         hash = hash << lshift ^ hash >> rshift ^ x->weight.Hash();
    105       return hash;
    106     }
    107 
    108    private:
    109     int32 encode_flags_;
    110   };
    111 
    112   typedef unordered_map<const Tuple*,
    113                    Label,
    114                    TupleKey,
    115                    TupleEqual> EncodeHash;
    116 
    117   explicit EncodeTable(uint32 encode_flags)
    118       : flags_(encode_flags),
    119         encode_hash_(1024, TupleKey(encode_flags)),
    120         isymbols_(0), osymbols_(0) {}
    121 
    122   ~EncodeTable() {
    123     for (size_t i = 0; i < encode_tuples_.size(); ++i) {
    124       delete encode_tuples_[i];
    125     }
    126     delete isymbols_;
    127     delete osymbols_;
    128   }
    129 
    130   // Given an arc encode either input/ouptut labels or input/costs or both
    131   Label Encode(const A &arc) {
    132     const Tuple tuple(arc.ilabel,
    133                       flags_ & kEncodeLabels ? arc.olabel : 0,
    134                       flags_ & kEncodeWeights ? arc.weight : Weight::One());
    135     typename EncodeHash::const_iterator it = encode_hash_.find(&tuple);
    136     if (it == encode_hash_.end()) {
    137       encode_tuples_.push_back(new Tuple(tuple));
    138       encode_hash_[encode_tuples_.back()] = encode_tuples_.size();
    139       return encode_tuples_.size();
    140     } else {
    141       return it->second;
    142     }
    143   }
    144 
    145   // Given an arc, look up its encoded label. Returns kNoLabel if not found.
    146   Label GetLabel(const A &arc) const {
    147     const Tuple tuple(arc.ilabel,
    148                       flags_ & kEncodeLabels ? arc.olabel : 0,
    149                       flags_ & kEncodeWeights ? arc.weight : Weight::One());
    150     typename EncodeHash::const_iterator it = encode_hash_.find(&tuple);
    151     if (it == encode_hash_.end()) {
    152       return kNoLabel;
    153     } else {
    154       return it->second;
    155     }
    156   }
    157 
    158   // Given an encode arc Label decode back to input/output labels and costs
    159   const Tuple* Decode(Label key) const {
    160     if (key < 1 || key > encode_tuples_.size()) {
    161       LOG(ERROR) << "EncodeTable::Decode: unknown decode key: " << key;
    162       return 0;
    163     }
    164     return encode_tuples_[key - 1];
    165   }
    166 
    167   size_t Size() const { return encode_tuples_.size(); }
    168 
    169   bool Write(ostream &strm, const string &source) const;
    170 
    171   static EncodeTable<A> *Read(istream &strm, const string &source);
    172 
    173   const uint32 flags() const { return flags_ & kEncodeFlags; }
    174 
    175   int RefCount() const { return ref_count_.count(); }
    176   int IncrRefCount() { return ref_count_.Incr(); }
    177   int DecrRefCount() { return ref_count_.Decr(); }
    178 
    179 
    180   SymbolTable *InputSymbols() const { return isymbols_; }
    181 
    182   SymbolTable *OutputSymbols() const { return osymbols_; }
    183 
    184   void SetInputSymbols(const SymbolTable* syms) {
    185     if (isymbols_) delete isymbols_;
    186     if (syms) {
    187       isymbols_ = syms->Copy();
    188       flags_ |= kEncodeHasISymbols;
    189     } else {
    190       isymbols_ = 0;
    191       flags_ &= ~kEncodeHasISymbols;
    192     }
    193   }
    194 
    195   void SetOutputSymbols(const SymbolTable* syms) {
    196     if (osymbols_) delete osymbols_;
    197     if (syms) {
    198       osymbols_ = syms->Copy();
    199       flags_ |= kEncodeHasOSymbols;
    200     } else {
    201       osymbols_ = 0;
    202       flags_ &= ~kEncodeHasOSymbols;
    203     }
    204   }
    205 
    206  private:
    207   uint32 flags_;
    208   vector<Tuple*> encode_tuples_;
    209   EncodeHash encode_hash_;
    210   RefCounter ref_count_;
    211   SymbolTable *isymbols_;       // Pre-encoded ilabel symbol table
    212   SymbolTable *osymbols_;       // Pre-encoded olabel symbol table
    213 
    214   DISALLOW_COPY_AND_ASSIGN(EncodeTable);
    215 };
    216 
    217 template <class A> inline
    218 bool EncodeTable<A>::Write(ostream &strm, const string &source) const {
    219   WriteType(strm, kEncodeMagicNumber);
    220   WriteType(strm, flags_);
    221   int64 size = encode_tuples_.size();
    222   WriteType(strm, size);
    223   for (size_t i = 0;  i < size; ++i) {
    224     const Tuple* tuple = encode_tuples_[i];
    225     WriteType(strm, tuple->ilabel);
    226     WriteType(strm, tuple->olabel);
    227     tuple->weight.Write(strm);
    228   }
    229 
    230   if (flags_ & kEncodeHasISymbols)
    231     isymbols_->Write(strm);
    232 
    233   if (flags_ & kEncodeHasOSymbols)
    234     osymbols_->Write(strm);
    235 
    236   strm.flush();
    237   if (!strm) {
    238     LOG(ERROR) << "EncodeTable::Write: write failed: " << source;
    239     return false;
    240   }
    241   return true;
    242 }
    243 
    244 template <class A> inline
    245 EncodeTable<A> *EncodeTable<A>::Read(istream &strm, const string &source) {
    246   int32 magic_number = 0;
    247   ReadType(strm, &magic_number);
    248   if (magic_number != kEncodeMagicNumber) {
    249     LOG(ERROR) << "EncodeTable::Read: Bad encode table header: " << source;
    250     return 0;
    251   }
    252   uint32 flags;
    253   ReadType(strm, &flags);
    254   EncodeTable<A> *table = new EncodeTable<A>(flags);
    255 
    256   int64 size;
    257   ReadType(strm, &size);
    258   if (!strm) {
    259     LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
    260     return 0;
    261   }
    262 
    263   for (size_t i = 0; i < size; ++i) {
    264     Tuple* tuple = new Tuple();
    265     ReadType(strm, &tuple->ilabel);
    266     ReadType(strm, &tuple->olabel);
    267     tuple->weight.Read(strm);
    268     if (!strm) {
    269       LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
    270       return 0;
    271     }
    272     table->encode_tuples_.push_back(tuple);
    273     table->encode_hash_[table->encode_tuples_.back()] =
    274         table->encode_tuples_.size();
    275   }
    276 
    277   if (flags & kEncodeHasISymbols)
    278     table->isymbols_ = SymbolTable::Read(strm, source);
    279 
    280   if (flags & kEncodeHasOSymbols)
    281     table->osymbols_ = SymbolTable::Read(strm, source);
    282 
    283   return table;
    284 }
    285 
    286 
    287 // A mapper to encode/decode weighted transducers. Encoding of an
    288 // Fst is useful for performing classical determinization or minimization
    289 // on a weighted transducer by treating it as an unweighted acceptor over
    290 // encoded labels.
    291 //
    292 // The Encode mapper stores the encoding in a local hash table (EncodeTable)
    293 // This table is shared (and reference counted) between the encoder and
    294 // decoder. A decoder has read only access to the EncodeTable.
    295 //
    296 // The EncodeMapper allows on the fly encoding of the machine. As the
    297 // EncodeTable is generated the same table may be used to decode the machine
    298 // on the fly. For example in the following sequence of operations
    299 //
    300 //  Encode -> Determinize -> Decode
    301 //
    302 // we will use the encoding table generated during the encode step in the
    303 // decode, even though the encoding is not complete.
    304 //
    305 template <class A> class EncodeMapper {
    306   typedef typename A::Weight Weight;
    307   typedef typename A::Label  Label;
    308  public:
    309   EncodeMapper(uint32 flags, EncodeType type)
    310     : flags_(flags),
    311       type_(type),
    312       table_(new EncodeTable<A>(flags)),
    313       error_(false) {}
    314 
    315   EncodeMapper(const EncodeMapper& mapper)
    316       : flags_(mapper.flags_),
    317         type_(mapper.type_),
    318         table_(mapper.table_),
    319         error_(false) {
    320     table_->IncrRefCount();
    321   }
    322 
    323   // Copy constructor but setting the type, typically to DECODE
    324   EncodeMapper(const EncodeMapper& mapper, EncodeType type)
    325       : flags_(mapper.flags_),
    326         type_(type),
    327         table_(mapper.table_),
    328         error_(mapper.error_) {
    329     table_->IncrRefCount();
    330   }
    331 
    332   ~EncodeMapper() {
    333     if (!table_->DecrRefCount()) delete table_;
    334   }
    335 
    336   A operator()(const A &arc);
    337 
    338   MapFinalAction FinalAction() const {
    339     return (type_ == ENCODE && (flags_ & kEncodeWeights)) ?
    340                    MAP_REQUIRE_SUPERFINAL : MAP_NO_SUPERFINAL;
    341   }
    342 
    343   MapSymbolsAction InputSymbolsAction() const { return MAP_CLEAR_SYMBOLS; }
    344 
    345   MapSymbolsAction OutputSymbolsAction() const { return MAP_CLEAR_SYMBOLS;}
    346 
    347   uint64 Properties(uint64 inprops) {
    348     uint64 outprops = inprops;
    349     if (error_) outprops |= kError;
    350 
    351     uint64 mask = kFstProperties;
    352     if (flags_ & kEncodeLabels)
    353       mask &= kILabelInvariantProperties & kOLabelInvariantProperties;
    354     if (flags_ & kEncodeWeights)
    355       mask &= kILabelInvariantProperties & kWeightInvariantProperties &
    356           (type_ == ENCODE ? kAddSuperFinalProperties :
    357            kRmSuperFinalProperties);
    358 
    359     return outprops & mask;
    360   }
    361 
    362   const uint32 flags() const { return flags_; }
    363   const EncodeType type() const { return type_; }
    364   const EncodeTable<A> &table() const { return *table_; }
    365 
    366   bool Write(ostream &strm, const string& source) {
    367     return table_->Write(strm, source);
    368   }
    369 
    370   bool Write(const string& filename) {
    371     ofstream strm(filename.c_str(), ofstream::out | ofstream::binary);
    372     if (!strm) {
    373       LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
    374       return false;
    375     }
    376     return Write(strm, filename);
    377   }
    378 
    379   static EncodeMapper<A> *Read(istream &strm,
    380                                const string& source,
    381                                EncodeType type = ENCODE) {
    382     EncodeTable<A> *table = EncodeTable<A>::Read(strm, source);
    383     return table ? new EncodeMapper(table->flags(), type, table) : 0;
    384   }
    385 
    386   static EncodeMapper<A> *Read(const string& filename,
    387                                EncodeType type = ENCODE) {
    388     ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
    389     if (!strm) {
    390       LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
    391       return NULL;
    392     }
    393     return Read(strm, filename, type);
    394   }
    395 
    396   SymbolTable *InputSymbols() const { return table_->InputSymbols(); }
    397 
    398   SymbolTable *OutputSymbols() const { return table_->OutputSymbols(); }
    399 
    400   void SetInputSymbols(const SymbolTable* syms) {
    401     table_->SetInputSymbols(syms);
    402   }
    403 
    404   void SetOutputSymbols(const SymbolTable* syms) {
    405     table_->SetOutputSymbols(syms);
    406   }
    407 
    408  private:
    409   uint32 flags_;
    410   EncodeType type_;
    411   EncodeTable<A>* table_;
    412   bool error_;
    413 
    414   explicit EncodeMapper(uint32 flags, EncodeType type, EncodeTable<A> *table)
    415       : flags_(flags), type_(type), table_(table) {}
    416   void operator=(const EncodeMapper &);  // Disallow.
    417 };
    418 
    419 template <class A> inline
    420 A EncodeMapper<A>::operator()(const A &arc) {
    421   if (type_ == ENCODE) {  // labels and/or weights to single label
    422     if ((arc.nextstate == kNoStateId && !(flags_ & kEncodeWeights)) ||
    423         (arc.nextstate == kNoStateId && (flags_ & kEncodeWeights) &&
    424          arc.weight == Weight::Zero())) {
    425       return arc;
    426     } else {
    427       Label label = table_->Encode(arc);
    428       return A(label,
    429                flags_ & kEncodeLabels ? label : arc.olabel,
    430                flags_ & kEncodeWeights ? Weight::One() : arc.weight,
    431                arc.nextstate);
    432     }
    433   } else {  // type_ == DECODE
    434     if (arc.nextstate == kNoStateId) {
    435       return arc;
    436     } else {
    437       if (arc.ilabel == 0) return arc;
    438       if (flags_ & kEncodeLabels && arc.ilabel != arc.olabel) {
    439         FSTERROR() << "EncodeMapper: Label-encoded arc has different "
    440             "input and output labels";
    441         error_ = true;
    442       }
    443       if (flags_ & kEncodeWeights && arc.weight != Weight::One()) {
    444         FSTERROR() <<
    445             "EncodeMapper: Weight-encoded arc has non-trivial weight";
    446         error_ = true;
    447       }
    448       const typename EncodeTable<A>::Tuple* tuple = table_->Decode(arc.ilabel);
    449       if (!tuple) {
    450         FSTERROR() << "EncodeMapper: decode failed";
    451         error_ = true;
    452         return A(kNoLabel, kNoLabel, Weight::NoWeight(), arc.nextstate);
    453       } else {
    454         return A(tuple->ilabel,
    455                  flags_ & kEncodeLabels ? tuple->olabel : arc.olabel,
    456                  flags_ & kEncodeWeights ? tuple->weight : arc.weight,
    457                  arc.nextstate);
    458       }
    459     }
    460   }
    461 }
    462 
    463 
    464 // Complexity: O(nstates + narcs)
    465 template<class A> inline
    466 void Encode(MutableFst<A> *fst, EncodeMapper<A>* mapper) {
    467   mapper->SetInputSymbols(fst->InputSymbols());
    468   mapper->SetOutputSymbols(fst->OutputSymbols());
    469   ArcMap(fst, mapper);
    470 }
    471 
    472 template<class A> inline
    473 void Decode(MutableFst<A>* fst, const EncodeMapper<A>& mapper) {
    474   ArcMap(fst, EncodeMapper<A>(mapper, DECODE));
    475   RmFinalEpsilon(fst);
    476   fst->SetInputSymbols(mapper.InputSymbols());
    477   fst->SetOutputSymbols(mapper.OutputSymbols());
    478 }
    479 
    480 
    481 // On the fly label and/or weight encoding of input Fst
    482 //
    483 // Complexity:
    484 // - Constructor: O(1)
    485 // - Traversal: O(nstates_visited + narcs_visited), assuming constant
    486 //   time to visit an input state or arc.
    487 template <class A>
    488 class EncodeFst : public ArcMapFst<A, A, EncodeMapper<A> > {
    489  public:
    490   typedef A Arc;
    491   typedef EncodeMapper<A> C;
    492   typedef ArcMapFstImpl< A, A, EncodeMapper<A> > Impl;
    493   using ImplToFst<Impl>::GetImpl;
    494 
    495   EncodeFst(const Fst<A> &fst, EncodeMapper<A>* encoder)
    496       : ArcMapFst<A, A, C>(fst, encoder, ArcMapFstOptions()) {
    497     encoder->SetInputSymbols(fst.InputSymbols());
    498     encoder->SetOutputSymbols(fst.OutputSymbols());
    499   }
    500 
    501   EncodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder)
    502       : ArcMapFst<A, A, C>(fst, encoder, ArcMapFstOptions()) {}
    503 
    504   // See Fst<>::Copy() for doc.
    505   EncodeFst(const EncodeFst<A> &fst, bool copy = false)
    506       : ArcMapFst<A, A, C>(fst, copy) {}
    507 
    508   // Get a copy of this EncodeFst. See Fst<>::Copy() for further doc.
    509   virtual EncodeFst<A> *Copy(bool safe = false) const {
    510     if (safe) {
    511       FSTERROR() << "EncodeFst::Copy(true): not allowed.";
    512       GetImpl()->SetProperties(kError, kError);
    513     }
    514     return new EncodeFst(*this);
    515   }
    516 };
    517 
    518 
    519 // On the fly label and/or weight decoding of input Fst
    520 //
    521 // Complexity:
    522 // - Constructor: O(1)
    523 // - Traversal: O(nstates_visited + narcs_visited), assuming constant
    524 //   time to visit an input state or arc.
    525 template <class A>
    526 class DecodeFst : public ArcMapFst<A, A, EncodeMapper<A> > {
    527  public:
    528   typedef A Arc;
    529   typedef EncodeMapper<A> C;
    530   typedef ArcMapFstImpl< A, A, EncodeMapper<A> > Impl;
    531   using ImplToFst<Impl>::GetImpl;
    532 
    533   DecodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder)
    534       : ArcMapFst<A, A, C>(fst,
    535                             EncodeMapper<A>(encoder, DECODE),
    536                             ArcMapFstOptions()) {
    537     GetImpl()->SetInputSymbols(encoder.InputSymbols());
    538     GetImpl()->SetOutputSymbols(encoder.OutputSymbols());
    539   }
    540 
    541   // See Fst<>::Copy() for doc.
    542   DecodeFst(const DecodeFst<A> &fst, bool safe = false)
    543       : ArcMapFst<A, A, C>(fst, safe) {}
    544 
    545   // Get a copy of this DecodeFst. See Fst<>::Copy() for further doc.
    546   virtual DecodeFst<A> *Copy(bool safe = false) const {
    547     return new DecodeFst(*this, safe);
    548   }
    549 };
    550 
    551 
    552 // Specialization for EncodeFst.
    553 template <class A>
    554 class StateIterator< EncodeFst<A> >
    555     : public StateIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
    556  public:
    557   explicit StateIterator(const EncodeFst<A> &fst)
    558       : StateIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst) {}
    559 };
    560 
    561 
    562 // Specialization for EncodeFst.
    563 template <class A>
    564 class ArcIterator< EncodeFst<A> >
    565     : public ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
    566  public:
    567   ArcIterator(const EncodeFst<A> &fst, typename A::StateId s)
    568       : ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst, s) {}
    569 };
    570 
    571 
    572 // Specialization for DecodeFst.
    573 template <class A>
    574 class StateIterator< DecodeFst<A> >
    575     : public StateIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
    576  public:
    577   explicit StateIterator(const DecodeFst<A> &fst)
    578       : StateIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst) {}
    579 };
    580 
    581 
    582 // Specialization for DecodeFst.
    583 template <class A>
    584 class ArcIterator< DecodeFst<A> >
    585     : public ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
    586  public:
    587   ArcIterator(const DecodeFst<A> &fst, typename A::StateId s)
    588       : ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst, s) {}
    589 };
    590 
    591 
    592 // Convenience aliases for the StdArc instantiations.
    593 typedef EncodeFst<StdArc> StdEncodeFst;  // On-the-fly encoder over StdArc.
    594 
    595 typedef DecodeFst<StdArc> StdDecodeFst;  // On-the-fly decoder over StdArc.
    596 
    597 }  // namespace fst
    598 
    599 #endif  // FST_LIB_ENCODE_H__
    600