Home | History | Annotate | Download | only in lib
      1 // encode.h
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 //
     16 // \file
// Classes to encode and decode an Fst.
     18 
     19 #ifndef FST_LIB_ENCODE_H__
     20 #define FST_LIB_ENCODE_H__
     21 
     22 #include "fst/lib/map.h"
     23 #include "fst/lib/rmfinalepsilon.h"
     24 
     25 namespace fst {
     26 
     27 static const uint32 kEncodeLabels = 0x00001;
     28 static const uint32 kEncodeWeights  = 0x00002;
     29 
     30 enum EncodeType { ENCODE = 1, DECODE = 2 };
     31 
     32 // Identifies stream data as an encode table (and its endianity)
     33 static const int32 kEncodeMagicNumber = 2129983209;
     34 
     35 
     36 // The following class encapsulates implementation details for the
     37 // encoding and decoding of label/weight tuples used for encoding
     38 // and decoding of Fsts. The EncodeTable is bidirectional. I.E it
     39 // stores both the Tuple of encode labels and weights to a unique
     40 // label, and the reverse.
     41 template <class A>  class EncodeTable {
     42  public:
     43   typedef typename A::Label Label;
     44   typedef typename A::Weight Weight;
     45 
     46   // Encoded data consists of arc input/output labels and arc weight
     47   struct Tuple {
     48     Tuple() {}
     49     Tuple(Label ilabel_, Label olabel_, Weight weight_)
     50         : ilabel(ilabel_), olabel(olabel_), weight(weight_) {}
     51     Tuple(const Tuple& tuple)
     52         : ilabel(tuple.ilabel), olabel(tuple.olabel), weight(tuple.weight) {}
     53 
     54     Label ilabel;
     55     Label olabel;
     56     Weight weight;
     57   };
     58 
     59   // Comparison object for hashing EncodeTable Tuple(s).
     60   class TupleEqual {
     61    public:
     62     bool operator()(const Tuple* x, const Tuple* y) const {
     63       return (x->ilabel == y->ilabel &&
     64               x->olabel == y->olabel &&
     65               x->weight == y->weight);
     66     }
     67   };
     68 
     69   // Hash function for EncodeTabe Tuples. Based on the encode flags
     70   // we either hash the labels, weights or compbination of them.
     71   class TupleKey {
     72     static const int kPrime = 7853;
     73    public:
     74     TupleKey()
     75         : encode_flags_(kEncodeLabels | kEncodeWeights) {}
     76 
     77     TupleKey(const TupleKey& key)
     78         : encode_flags_(key.encode_flags_) {}
     79 
     80     explicit TupleKey(uint32 encode_flags)
     81         : encode_flags_(encode_flags) {}
     82 
     83     size_t operator()(const Tuple* x) const {
     84       int lshift = x->ilabel % kPrime;
     85       int rshift = sizeof(size_t) - lshift;
     86       size_t hash = x->ilabel << lshift;
     87       if (encode_flags_ & kEncodeLabels) hash ^= x->olabel >> rshift;
     88       if (encode_flags_ & kEncodeWeights)  hash ^= x->weight.Hash();
     89       return hash;
     90     }
     91 
     92    private:
     93     int32 encode_flags_;
     94   };
     95 
     96   typedef std::unordered_map<const Tuple*, Label, TupleKey, TupleEqual> EncodeHash;
     97 
     98   explicit EncodeTable(uint32 encode_flags)
     99       : flags_(encode_flags),
    100         encode_hash_(1024, TupleKey(encode_flags)) {}
    101 
    102   ~EncodeTable() {
    103     for (size_t i = 0; i < encode_tuples_.size(); ++i) {
    104       delete encode_tuples_[i];
    105     }
    106   }
    107 
    108   // Given an arc encode either input/ouptut labels or input/costs or both
    109   Label Encode(const A &arc) {
    110     const Tuple tuple(arc.ilabel,
    111                       flags_ & kEncodeLabels ? arc.olabel : 0,
    112                       flags_ & kEncodeWeights ? arc.weight : Weight::One());
    113     typename EncodeHash::const_iterator it = encode_hash_.find(&tuple);
    114     if (it == encode_hash_.end()) {
    115       encode_tuples_.push_back(new Tuple(tuple));
    116       encode_hash_[encode_tuples_.back()] = encode_tuples_.size();
    117       return encode_tuples_.size();
    118     } else {
    119       return it->second;
    120     }
    121   }
    122 
    123   // Given an encode arc Label decode back to input/output labels and costs
    124   const Tuple* Decode(Label key) {
    125     return key <= (Label)encode_tuples_.size() ? encode_tuples_[key - 1] : 0;
    126   }
    127 
    128   bool Write(ostream &strm, const string &source) const {
    129     WriteType(strm, kEncodeMagicNumber);
    130     WriteType(strm, flags_);
    131     int64 size = encode_tuples_.size();
    132     WriteType(strm, size);
    133     for (size_t i = 0;  i < size; ++i) {
    134       const Tuple* tuple = encode_tuples_[i];
    135       WriteType(strm, tuple->ilabel);
    136       WriteType(strm, tuple->olabel);
    137       tuple->weight.Write(strm);
    138     }
    139     strm.flush();
    140     if (!strm) {
    141       LOG(ERROR) << "EncodeTable::Write: write failed: " << source;
    142       return false;
    143     }
    144     return true;
    145   }
    146 
    147   bool Read(istream &strm, const string &source) {
    148     encode_tuples_.clear();
    149     encode_hash_.clear();
    150     int32 magic_number = 0;
    151     ReadType(strm, &magic_number);
    152     if (magic_number != kEncodeMagicNumber) {
    153       LOG(ERROR) << "EncodeTable::Read: Bad encode table header: " << source;
    154       return false;
    155     }
    156     ReadType(strm, &flags_);
    157     int64 size;
    158     ReadType(strm, &size);
    159     if (!strm) {
    160       LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
    161       return false;
    162     }
    163     for (size_t i = 0; i < size; ++i) {
    164       Tuple* tuple = new Tuple();
    165       ReadType(strm, &tuple->ilabel);
    166       ReadType(strm, &tuple->olabel);
    167       tuple->weight.Read(strm);
    168       encode_tuples_.push_back(tuple);
    169       encode_hash_[encode_tuples_.back()] = encode_tuples_.size();
    170     }
    171     if (!strm) {
    172       LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
    173       return false;
    174     }
    175     return true;
    176   }
    177 
    178   uint32 flags() const { return flags_; }
    179  private:
    180   uint32 flags_;
    181   vector<Tuple*> encode_tuples_;
    182   EncodeHash encode_hash_;
    183 
    184   DISALLOW_EVIL_CONSTRUCTORS(EncodeTable);
    185 };
    186 
    187 
    188 // A mapper to encode/decode weighted transducers. Encoding of an
    189 // Fst is useful for performing classical determinization or minimization
    190 // on a weighted transducer by treating it as an unweighted acceptor over
    191 // encoded labels.
    192 //
    193 // The Encode mapper stores the encoding in a local hash table (EncodeTable)
    194 // This table is shared (and reference counted) between the encoder and
    195 // decoder. A decoder has read only access to the EncodeTable.
    196 //
    197 // The EncodeMapper allows on the fly encoding of the machine. As the
    198 // EncodeTable is generated the same table may by used to decode the machine
    199 // on the fly. For example in the following sequence of operations
    200 //
    201 //  Encode -> Determinize -> Decode
    202 //
    203 // we will use the encoding table generated during the encode step in the
    204 // decode, even though the encoding is not complete.
    205 //
    206 template <class A> class EncodeMapper {
    207   typedef typename A::Weight Weight;
    208   typedef typename A::Label  Label;
    209  public:
    210   EncodeMapper(uint32 flags, EncodeType type)
    211     : ref_count_(1), flags_(flags), type_(type),
    212       table_(new EncodeTable<A>(flags)) {}
    213 
    214   EncodeMapper(const EncodeMapper& mapper)
    215       : ref_count_(mapper.ref_count_ + 1),
    216         flags_(mapper.flags_),
    217         type_(mapper.type_),
    218         table_(mapper.table_) { }
    219 
    220   // Copy constructor but setting the type, typically to DECODE
    221   EncodeMapper(const EncodeMapper& mapper, EncodeType type)
    222       : ref_count_(mapper.ref_count_ + 1),
    223         flags_(mapper.flags_),
    224         type_(type),
    225         table_(mapper.table_) { }
    226 
    227   ~EncodeMapper() {
    228     if (--ref_count_ == 0) delete table_;
    229   }
    230 
    231   A operator()(const A &arc) {
    232     if (type_ == ENCODE) {  // labels and/or weights to single label
    233       if ((arc.nextstate == kNoStateId && !(flags_ & kEncodeWeights)) ||
    234           (arc.nextstate == kNoStateId && (flags_ & kEncodeWeights) &&
    235            arc.weight == Weight::Zero())) {
    236         return arc;
    237       } else {
    238         Label label = table_->Encode(arc);
    239         return A(label,
    240                  flags_ & kEncodeLabels ? label : arc.olabel,
    241                  flags_ & kEncodeWeights ? Weight::One() : arc.weight,
    242                  arc.nextstate);
    243       }
    244     } else {
    245       if (arc.nextstate == kNoStateId) {
    246         return arc;
    247       } else {
    248         const typename EncodeTable<A>::Tuple* tuple =
    249           table_->Decode(arc.ilabel);
    250         return A(tuple->ilabel,
    251                  flags_ & kEncodeLabels ? tuple->olabel : arc.olabel,
    252                  flags_ & kEncodeWeights ? tuple->weight : arc.weight,
    253                  arc.nextstate);;
    254       }
    255     }
    256   }
    257 
    258   uint64 Properties(uint64 props) {
    259     uint64 mask = kFstProperties;
    260     if (flags_ & kEncodeLabels)
    261       mask &= kILabelInvariantProperties & kOLabelInvariantProperties;
    262     if (flags_ & kEncodeWeights)
    263       mask &= kILabelInvariantProperties & kWeightInvariantProperties &
    264           (type_ == ENCODE ? kAddSuperFinalProperties :
    265            kRmSuperFinalProperties);
    266     return props & mask;
    267   }
    268 
    269 
    270   MapFinalAction FinalAction() const {
    271     return (type_ == ENCODE && (flags_ & kEncodeWeights)) ?
    272                    MAP_REQUIRE_SUPERFINAL : MAP_NO_SUPERFINAL;
    273   }
    274 
    275   uint32 flags() const { return flags_; }
    276   EncodeType type() const { return type_; }
    277 
    278   bool Write(ostream &strm, const string& source) {
    279     return table_->Write(strm, source);
    280   }
    281 
    282   bool Write(const string& filename) {
    283     ofstream strm(filename.c_str());
    284     if (!strm) {
    285       LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
    286       return false;
    287     }
    288     return Write(strm, filename);
    289   }
    290 
    291   static EncodeMapper<A> *Read(istream &strm,
    292                                const string& source, EncodeType type) {
    293     EncodeTable<A> *table = new EncodeTable<A>(0);
    294     bool r = table->Read(strm, source);
    295     return r ? new EncodeMapper(table->flags(), type, table) : 0;
    296   }
    297 
    298   static EncodeMapper<A> *Read(const string& filename, EncodeType type) {
    299     ifstream strm(filename.c_str());
    300     if (!strm) {
    301       LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
    302       return false;
    303     }
    304     return Read(strm, filename, type);
    305   }
    306 
    307  private:
    308   uint32  ref_count_;
    309   uint32  flags_;
    310   EncodeType type_;
    311   EncodeTable<A>* table_;
    312 
    313   explicit EncodeMapper(uint32 flags, EncodeType type, EncodeTable<A> *table)
    314       : ref_count_(1), flags_(flags), type_(type), table_(table) {}
    315   void operator=(const EncodeMapper &);  // Disallow.
    316 };
    317 
    318 
    319 // Complexity: O(nstates + narcs)
    320 template<class A> inline
    321 void Encode(MutableFst<A> *fst, EncodeMapper<A>* mapper) {
    322   Map(fst, mapper);
    323 }
    324 
    325 
    326 template<class A> inline
    327 void Decode(MutableFst<A>* fst, const EncodeMapper<A>& mapper) {
    328   Map(fst, EncodeMapper<A>(mapper, DECODE));
    329   RmFinalEpsilon(fst);
    330 }
    331 
    332 
    333 // On the fly label and/or weight encoding of input Fst
    334 //
    335 // Complexity:
    336 // - Constructor: O(1)
    337 // - Traversal: O(nstates_visited + narcs_visited), assuming constant
    338 //   time to visit an input state or arc.
    339 template <class A>
    340 class EncodeFst : public MapFst<A, A, EncodeMapper<A> > {
    341  public:
    342   typedef A Arc;
    343   typedef EncodeMapper<A> C;
    344 
    345   EncodeFst(const Fst<A> &fst, EncodeMapper<A>* encoder)
    346       : MapFst<A, A, C>(fst, encoder, MapFstOptions()) {}
    347 
    348   EncodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder)
    349       : MapFst<A, A, C>(fst, encoder, MapFstOptions()) {}
    350 
    351   EncodeFst(const EncodeFst<A> &fst)
    352       : MapFst<A, A, C>(fst) {}
    353 
    354   virtual EncodeFst<A> *Copy() const { return new EncodeFst(*this); }
    355 };
    356 
    357 
    358 // On the fly label and/or weight encoding of input Fst
    359 //
    360 // Complexity:
    361 // - Constructor: O(1)
    362 // - Traversal: O(nstates_visited + narcs_visited), assuming constant
    363 //   time to visit an input state or arc.
    364 template <class A>
    365 class DecodeFst : public MapFst<A, A, EncodeMapper<A> > {
    366  public:
    367   typedef A Arc;
    368   typedef EncodeMapper<A> C;
    369 
    370   DecodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder)
    371       : MapFst<A, A, C>(fst,
    372                             EncodeMapper<A>(encoder, DECODE),
    373                             MapFstOptions()) {}
    374 
    375   DecodeFst(const EncodeFst<A> &fst)
    376       : MapFst<A, A, C>(fst) {}
    377 
    378   virtual DecodeFst<A> *Copy() const { return new DecodeFst(*this); }
    379 };
    380 
    381 
    382 // Specialization for EncodeFst.
    383 template <class A>
    384 class StateIterator< EncodeFst<A> >
    385     : public StateIterator< MapFst<A, A, EncodeMapper<A> > > {
    386  public:
    387   explicit StateIterator(const EncodeFst<A> &fst)
    388       : StateIterator< MapFst<A, A, EncodeMapper<A> > >(fst) {}
    389 };
    390 
    391 
    392 // Specialization for EncodeFst.
    393 template <class A>
    394 class ArcIterator< EncodeFst<A> >
    395     : public ArcIterator< MapFst<A, A, EncodeMapper<A> > > {
    396  public:
    397   ArcIterator(const EncodeFst<A> &fst, typename A::StateId s)
    398       : ArcIterator< MapFst<A, A, EncodeMapper<A> > >(fst, s) {}
    399 };
    400 
    401 
    402 // Specialization for DecodeFst.
    403 template <class A>
    404 class StateIterator< DecodeFst<A> >
    405     : public StateIterator< MapFst<A, A, EncodeMapper<A> > > {
    406  public:
    407   explicit StateIterator(const DecodeFst<A> &fst)
    408       : StateIterator< MapFst<A, A, EncodeMapper<A> > >(fst) {}
    409 };
    410 
    411 
    412 // Specialization for DecodeFst.
    413 template <class A>
    414 class ArcIterator< DecodeFst<A> >
    415     : public ArcIterator< MapFst<A, A, EncodeMapper<A> > > {
    416  public:
    417   ArcIterator(const DecodeFst<A> &fst, typename A::StateId s)
    418       : ArcIterator< MapFst<A, A, EncodeMapper<A> > >(fst, s) {}
    419 };
    420 
    421 
    422 // Useful aliases when using StdArc.
    423 typedef EncodeFst<StdArc> StdEncodeFst;
    424 
    425 typedef DecodeFst<StdArc> StdDecodeFst;
    426 
    427 }
    428 
    429 #endif  // FST_LIB_ENCODE_H__
    430