Home | History | Annotate | Download | only in lib
      1 // encode.h
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 //
     16 // \file
// Classes to encode and decode an Fst.
     18 
     19 #ifndef FST_LIB_ENCODE_H__
     20 #define FST_LIB_ENCODE_H__
     21 
     22 #include "fst/lib/map.h"
     23 #include "fst/lib/rmfinalepsilon.h"
     24 
     25 namespace fst {
     26 
     27 static const uint32 kEncodeLabels = 0x00001;
     28 static const uint32 kEncodeWeights  = 0x00002;
     29 
     30 enum EncodeType { ENCODE = 1, DECODE = 2 };
     31 
     32 // Identifies stream data as an encode table (and its endianity)
     33 static const int32 kEncodeMagicNumber = 2129983209;
     34 
     35 
     36 // The following class encapsulates implementation details for the
     37 // encoding and decoding of label/weight tuples used for encoding
     38 // and decoding of Fsts. The EncodeTable is bidirectional. I.E it
     39 // stores both the Tuple of encode labels and weights to a unique
     40 // label, and the reverse.
     41 template <class A>  class EncodeTable {
     42  public:
     43   typedef typename A::Label Label;
     44   typedef typename A::Weight Weight;
     45 
     46   // Encoded data consists of arc input/output labels and arc weight
     47   struct Tuple {
     48     Tuple() {}
     49     Tuple(Label ilabel_, Label olabel_, Weight weight_)
     50         : ilabel(ilabel_), olabel(olabel_), weight(weight_) {}
     51     Tuple(const Tuple& tuple)
     52         : ilabel(tuple.ilabel), olabel(tuple.olabel), weight(tuple.weight) {}
     53 
     54     Label ilabel;
     55     Label olabel;
     56     Weight weight;
     57   };
     58 
     59   // Comparison object for hashing EncodeTable Tuple(s).
     60   class TupleEqual {
     61    public:
     62     bool operator()(const Tuple* x, const Tuple* y) const {
     63       return (x->ilabel == y->ilabel &&
     64               x->olabel == y->olabel &&
     65               x->weight == y->weight);
     66     }
     67   };
     68 
     69   // Hash function for EncodeTabe Tuples. Based on the encode flags
     70   // we either hash the labels, weights or compbination of them.
     71   class TupleKey {
     72     static const int kPrime = 7853;
     73    public:
     74     TupleKey()
     75         : encode_flags_(kEncodeLabels | kEncodeWeights) {}
     76 
     77     TupleKey(const TupleKey& key)
     78         : encode_flags_(key.encode_flags_) {}
     79 
     80     explicit TupleKey(uint32 encode_flags)
     81         : encode_flags_(encode_flags) {}
     82 
     83     size_t operator()(const Tuple* x) const {
     84       int lshift = x->ilabel % kPrime;
     85       int rshift = sizeof(size_t) - lshift;
     86       size_t hash = x->ilabel << lshift;
     87       if (encode_flags_ & kEncodeLabels) hash ^= x->olabel >> rshift;
     88       if (encode_flags_ & kEncodeWeights)  hash ^= x->weight.Hash();
     89       return hash;
     90     }
     91 
     92    private:
     93     int32 encode_flags_;
     94   };
     95 
     96   typedef hash_map<const Tuple*,
     97                    Label,
     98                    TupleKey,
     99                    TupleEqual> EncodeHash;
    100 
    101   explicit EncodeTable(uint32 encode_flags)
    102       : flags_(encode_flags),
    103         encode_hash_(1024, TupleKey(encode_flags)) {}
    104 
    105   ~EncodeTable() {
    106     for (size_t i = 0; i < encode_tuples_.size(); ++i) {
    107       delete encode_tuples_[i];
    108     }
    109   }
    110 
    111   // Given an arc encode either input/ouptut labels or input/costs or both
    112   Label Encode(const A &arc) {
    113     const Tuple tuple(arc.ilabel,
    114                       flags_ & kEncodeLabels ? arc.olabel : 0,
    115                       flags_ & kEncodeWeights ? arc.weight : Weight::One());
    116     typename EncodeHash::const_iterator it = encode_hash_.find(&tuple);
    117     if (it == encode_hash_.end()) {
    118       encode_tuples_.push_back(new Tuple(tuple));
    119       encode_hash_[encode_tuples_.back()] = encode_tuples_.size();
    120       return encode_tuples_.size();
    121     } else {
    122       return it->second;
    123     }
    124   }
    125 
    126   // Given an encode arc Label decode back to input/output labels and costs
    127   const Tuple* Decode(Label key) {
    128     return key <= (Label)encode_tuples_.size() ? encode_tuples_[key - 1] : 0;
    129   }
    130 
    131   bool Write(ostream &strm, const string &source) const {
    132     WriteType(strm, kEncodeMagicNumber);
    133     WriteType(strm, flags_);
    134     int64 size = encode_tuples_.size();
    135     WriteType(strm, size);
    136     for (size_t i = 0;  i < size; ++i) {
    137       const Tuple* tuple = encode_tuples_[i];
    138       WriteType(strm, tuple->ilabel);
    139       WriteType(strm, tuple->olabel);
    140       tuple->weight.Write(strm);
    141     }
    142     strm.flush();
    143     if (!strm)
    144       LOG(ERROR) << "EncodeTable::Write: write failed: " << source;
    145     return strm;
    146   }
    147 
    148   bool Read(istream &strm, const string &source) {
    149     encode_tuples_.clear();
    150     encode_hash_.clear();
    151     int32 magic_number = 0;
    152     ReadType(strm, &magic_number);
    153     if (magic_number != kEncodeMagicNumber) {
    154       LOG(ERROR) << "EncodeTable::Read: Bad encode table header: " << source;
    155       return false;
    156     }
    157     ReadType(strm, &flags_);
    158     int64 size;
    159     ReadType(strm, &size);
    160     if (!strm) {
    161       LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
    162       return false;
    163     }
    164     for (size_t i = 0; i < size; ++i) {
    165       Tuple* tuple = new Tuple();
    166       ReadType(strm, &tuple->ilabel);
    167       ReadType(strm, &tuple->olabel);
    168       tuple->weight.Read(strm);
    169       encode_tuples_.push_back(tuple);
    170       encode_hash_[encode_tuples_.back()] = encode_tuples_.size();
    171     }
    172     if (!strm)
    173       LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
    174     return strm;
    175   }
    176 
    177   const uint32 flags() const { return flags_; }
    178  private:
    179   uint32 flags_;
    180   vector<Tuple*> encode_tuples_;
    181   EncodeHash encode_hash_;
    182 
    183   DISALLOW_EVIL_CONSTRUCTORS(EncodeTable);
    184 };
    185 
    186 
    187 // A mapper to encode/decode weighted transducers. Encoding of an
    188 // Fst is useful for performing classical determinization or minimization
    189 // on a weighted transducer by treating it as an unweighted acceptor over
    190 // encoded labels.
    191 //
    192 // The Encode mapper stores the encoding in a local hash table (EncodeTable)
    193 // This table is shared (and reference counted) between the encoder and
    194 // decoder. A decoder has read only access to the EncodeTable.
    195 //
    196 // The EncodeMapper allows on the fly encoding of the machine. As the
    197 // EncodeTable is generated the same table may by used to decode the machine
    198 // on the fly. For example in the following sequence of operations
    199 //
    200 //  Encode -> Determinize -> Decode
    201 //
    202 // we will use the encoding table generated during the encode step in the
    203 // decode, even though the encoding is not complete.
    204 //
    205 template <class A> class EncodeMapper {
    206   typedef typename A::Weight Weight;
    207   typedef typename A::Label  Label;
    208  public:
    209   EncodeMapper(uint32 flags, EncodeType type)
    210     : ref_count_(1), flags_(flags), type_(type),
    211       table_(new EncodeTable<A>(flags)) {}
    212 
    213   EncodeMapper(const EncodeMapper& mapper)
    214       : ref_count_(mapper.ref_count_ + 1),
    215         flags_(mapper.flags_),
    216         type_(mapper.type_),
    217         table_(mapper.table_) { }
    218 
    219   // Copy constructor but setting the type, typically to DECODE
    220   EncodeMapper(const EncodeMapper& mapper, EncodeType type)
    221       : ref_count_(mapper.ref_count_ + 1),
    222         flags_(mapper.flags_),
    223         type_(type),
    224         table_(mapper.table_) { }
    225 
    226   ~EncodeMapper() {
    227     if (--ref_count_ == 0) delete table_;
    228   }
    229 
    230   A operator()(const A &arc) {
    231     if (type_ == ENCODE) {  // labels and/or weights to single label
    232       if ((arc.nextstate == kNoStateId && !(flags_ & kEncodeWeights)) ||
    233           (arc.nextstate == kNoStateId && (flags_ & kEncodeWeights) &&
    234            arc.weight == Weight::Zero())) {
    235         return arc;
    236       } else {
    237         Label label = table_->Encode(arc);
    238         return A(label,
    239                  flags_ & kEncodeLabels ? label : arc.olabel,
    240                  flags_ & kEncodeWeights ? Weight::One() : arc.weight,
    241                  arc.nextstate);
    242       }
    243     } else {
    244       if (arc.nextstate == kNoStateId) {
    245         return arc;
    246       } else {
    247         const typename EncodeTable<A>::Tuple* tuple =
    248           table_->Decode(arc.ilabel);
    249         return A(tuple->ilabel,
    250                  flags_ & kEncodeLabels ? tuple->olabel : arc.olabel,
    251                  flags_ & kEncodeWeights ? tuple->weight : arc.weight,
    252                  arc.nextstate);;
    253       }
    254     }
    255   }
    256 
    257   uint64 Properties(uint64 props) {
    258     uint64 mask = kFstProperties;
    259     if (flags_ & kEncodeLabels)
    260       mask &= kILabelInvariantProperties & kOLabelInvariantProperties;
    261     if (flags_ & kEncodeWeights)
    262       mask &= kILabelInvariantProperties & kWeightInvariantProperties &
    263           (type_ == ENCODE ? kAddSuperFinalProperties :
    264            kRmSuperFinalProperties);
    265     return props & mask;
    266   }
    267 
    268 
    269   MapFinalAction FinalAction() const {
    270     return (type_ == ENCODE && (flags_ & kEncodeWeights)) ?
    271                    MAP_REQUIRE_SUPERFINAL : MAP_NO_SUPERFINAL;
    272   }
    273 
    274   const uint32 flags() const { return flags_; }
    275   const EncodeType type() const { return type_; }
    276 
    277   bool Write(ostream &strm, const string& source) {
    278     return table_->Write(strm, source);
    279   }
    280 
    281   bool Write(const string& filename) {
    282     ofstream strm(filename.c_str());
    283     if (!strm) {
    284       LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
    285       return false;
    286     }
    287     return Write(strm, filename);
    288   }
    289 
    290   static EncodeMapper<A> *Read(istream &strm,
    291                                const string& source, EncodeType type) {
    292     EncodeTable<A> *table = new EncodeTable<A>(0);
    293     bool r = table->Read(strm, source);
    294     return r ? new EncodeMapper(table->flags(), type, table) : 0;
    295   }
    296 
    297   static EncodeMapper<A> *Read(const string& filename, EncodeType type) {
    298     ifstream strm(filename.c_str());
    299     if (!strm) {
    300       LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
    301       return false;
    302     }
    303     return Read(strm, filename, type);
    304   }
    305 
    306  private:
    307   uint32  ref_count_;
    308   uint32  flags_;
    309   EncodeType type_;
    310   EncodeTable<A>* table_;
    311 
    312   explicit EncodeMapper(uint32 flags, EncodeType type, EncodeTable<A> *table)
    313       : ref_count_(1), flags_(flags), type_(type), table_(table) {}
    314   void operator=(const EncodeMapper &);  // Disallow.
    315 };
    316 
    317 
    318 // Complexity: O(nstates + narcs)
    319 template<class A> inline
    320 void Encode(MutableFst<A> *fst, EncodeMapper<A>* mapper) {
    321   Map(fst, mapper);
    322 }
    323 
    324 
    325 template<class A> inline
    326 void Decode(MutableFst<A>* fst, const EncodeMapper<A>& mapper) {
    327   Map(fst, EncodeMapper<A>(mapper, DECODE));
    328   RmFinalEpsilon(fst);
    329 }
    330 
    331 
    332 // On the fly label and/or weight encoding of input Fst
    333 //
    334 // Complexity:
    335 // - Constructor: O(1)
    336 // - Traversal: O(nstates_visited + narcs_visited), assuming constant
    337 //   time to visit an input state or arc.
    338 template <class A>
    339 class EncodeFst : public MapFst<A, A, EncodeMapper<A> > {
    340  public:
    341   typedef A Arc;
    342   typedef EncodeMapper<A> C;
    343 
    344   EncodeFst(const Fst<A> &fst, EncodeMapper<A>* encoder)
    345       : MapFst<A, A, C>(fst, encoder, MapFstOptions()) {}
    346 
    347   EncodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder)
    348       : MapFst<A, A, C>(fst, encoder, MapFstOptions()) {}
    349 
    350   EncodeFst(const EncodeFst<A> &fst)
    351       : MapFst<A, A, C>(fst) {}
    352 
    353   virtual EncodeFst<A> *Copy() const { return new EncodeFst(*this); }
    354 };
    355 
    356 
    357 // On the fly label and/or weight encoding of input Fst
    358 //
    359 // Complexity:
    360 // - Constructor: O(1)
    361 // - Traversal: O(nstates_visited + narcs_visited), assuming constant
    362 //   time to visit an input state or arc.
    363 template <class A>
    364 class DecodeFst : public MapFst<A, A, EncodeMapper<A> > {
    365  public:
    366   typedef A Arc;
    367   typedef EncodeMapper<A> C;
    368 
    369   DecodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder)
    370       : MapFst<A, A, C>(fst,
    371                             EncodeMapper<A>(encoder, DECODE),
    372                             MapFstOptions()) {}
    373 
    374   DecodeFst(const EncodeFst<A> &fst)
    375       : MapFst<A, A, C>(fst) {}
    376 
    377   virtual DecodeFst<A> *Copy() const { return new DecodeFst(*this); }
    378 };
    379 
    380 
    381 // Specialization for EncodeFst.
    382 template <class A>
    383 class StateIterator< EncodeFst<A> >
    384     : public StateIterator< MapFst<A, A, EncodeMapper<A> > > {
    385  public:
    386   explicit StateIterator(const EncodeFst<A> &fst)
    387       : StateIterator< MapFst<A, A, EncodeMapper<A> > >(fst) {}
    388 };
    389 
    390 
    391 // Specialization for EncodeFst.
    392 template <class A>
    393 class ArcIterator< EncodeFst<A> >
    394     : public ArcIterator< MapFst<A, A, EncodeMapper<A> > > {
    395  public:
    396   ArcIterator(const EncodeFst<A> &fst, typename A::StateId s)
    397       : ArcIterator< MapFst<A, A, EncodeMapper<A> > >(fst, s) {}
    398 };
    399 
    400 
    401 // Specialization for DecodeFst.
    402 template <class A>
    403 class StateIterator< DecodeFst<A> >
    404     : public StateIterator< MapFst<A, A, EncodeMapper<A> > > {
    405  public:
    406   explicit StateIterator(const DecodeFst<A> &fst)
    407       : StateIterator< MapFst<A, A, EncodeMapper<A> > >(fst) {}
    408 };
    409 
    410 
    411 // Specialization for DecodeFst.
    412 template <class A>
    413 class ArcIterator< DecodeFst<A> >
    414     : public ArcIterator< MapFst<A, A, EncodeMapper<A> > > {
    415  public:
    416   ArcIterator(const DecodeFst<A> &fst, typename A::StateId s)
    417       : ArcIterator< MapFst<A, A, EncodeMapper<A> > >(fst, s) {}
    418 };
    419 
    420 
    421 // Useful aliases when using StdArc.
    422 typedef EncodeFst<StdArc> StdEncodeFst;
    423 
    424 typedef DecodeFst<StdArc> StdDecodeFst;
    425 
    426 }
    427 
    428 #endif  // FST_LIB_ENCODE_H__
    429