1 // encode.h 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // 16 // \file 17 // Class to encode and decoder an fst. 18 19 #ifndef FST_LIB_ENCODE_H__ 20 #define FST_LIB_ENCODE_H__ 21 22 #include "fst/lib/map.h" 23 #include "fst/lib/rmfinalepsilon.h" 24 25 namespace fst { 26 27 static const uint32 kEncodeLabels = 0x00001; 28 static const uint32 kEncodeWeights = 0x00002; 29 30 enum EncodeType { ENCODE = 1, DECODE = 2 }; 31 32 // Identifies stream data as an encode table (and its endianity) 33 static const int32 kEncodeMagicNumber = 2129983209; 34 35 36 // The following class encapsulates implementation details for the 37 // encoding and decoding of label/weight tuples used for encoding 38 // and decoding of Fsts. The EncodeTable is bidirectional. I.E it 39 // stores both the Tuple of encode labels and weights to a unique 40 // label, and the reverse. 41 template <class A> class EncodeTable { 42 public: 43 typedef typename A::Label Label; 44 typedef typename A::Weight Weight; 45 46 // Encoded data consists of arc input/output labels and arc weight 47 struct Tuple { 48 Tuple() {} 49 Tuple(Label ilabel_, Label olabel_, Weight weight_) 50 : ilabel(ilabel_), olabel(olabel_), weight(weight_) {} 51 Tuple(const Tuple& tuple) 52 : ilabel(tuple.ilabel), olabel(tuple.olabel), weight(tuple.weight) {} 53 54 Label ilabel; 55 Label olabel; 56 Weight weight; 57 }; 58 59 // Comparison object for hashing EncodeTable Tuple(s). 60 class TupleEqual { 61 public: 62 bool operator()(const Tuple* x, const Tuple* y) const { 63 return (x->ilabel == y->ilabel && 64 x->olabel == y->olabel && 65 x->weight == y->weight); 66 } 67 }; 68 69 // Hash function for EncodeTabe Tuples. Based on the encode flags 70 // we either hash the labels, weights or compbination of them. 71 class TupleKey { 72 static const int kPrime = 7853; 73 public: 74 TupleKey() 75 : encode_flags_(kEncodeLabels | kEncodeWeights) {} 76 77 TupleKey(const TupleKey& key) 78 : encode_flags_(key.encode_flags_) {} 79 80 explicit TupleKey(uint32 encode_flags) 81 : encode_flags_(encode_flags) {} 82 83 size_t operator()(const Tuple* x) const { 84 int lshift = x->ilabel % kPrime; 85 int rshift = sizeof(size_t) - lshift; 86 size_t hash = x->ilabel << lshift; 87 if (encode_flags_ & kEncodeLabels) hash ^= x->olabel >> rshift; 88 if (encode_flags_ & kEncodeWeights) hash ^= x->weight.Hash(); 89 return hash; 90 } 91 92 private: 93 int32 encode_flags_; 94 }; 95 96 typedef std::unordered_map<const Tuple*, Label, TupleKey, TupleEqual> EncodeHash; 97 98 explicit EncodeTable(uint32 encode_flags) 99 : flags_(encode_flags), 100 encode_hash_(1024, TupleKey(encode_flags)) {} 101 102 ~EncodeTable() { 103 for (size_t i = 0; i < encode_tuples_.size(); ++i) { 104 delete encode_tuples_[i]; 105 } 106 } 107 108 // Given an arc encode either input/ouptut labels or input/costs or both 109 Label Encode(const A &arc) { 110 const Tuple tuple(arc.ilabel, 111 flags_ & kEncodeLabels ? arc.olabel : 0, 112 flags_ & kEncodeWeights ? arc.weight : Weight::One()); 113 typename EncodeHash::const_iterator it = encode_hash_.find(&tuple); 114 if (it == encode_hash_.end()) { 115 encode_tuples_.push_back(new Tuple(tuple)); 116 encode_hash_[encode_tuples_.back()] = encode_tuples_.size(); 117 return encode_tuples_.size(); 118 } else { 119 return it->second; 120 } 121 } 122 123 // Given an encode arc Label decode back to input/output labels and costs 124 const Tuple* Decode(Label key) { 125 return key <= (Label)encode_tuples_.size() ? encode_tuples_[key - 1] : 0; 126 } 127 128 bool Write(ostream &strm, const string &source) const { 129 WriteType(strm, kEncodeMagicNumber); 130 WriteType(strm, flags_); 131 int64 size = encode_tuples_.size(); 132 WriteType(strm, size); 133 for (size_t i = 0; i < size; ++i) { 134 const Tuple* tuple = encode_tuples_[i]; 135 WriteType(strm, tuple->ilabel); 136 WriteType(strm, tuple->olabel); 137 tuple->weight.Write(strm); 138 } 139 strm.flush(); 140 if (!strm) { 141 LOG(ERROR) << "EncodeTable::Write: write failed: " << source; 142 return false; 143 } 144 return true; 145 } 146 147 bool Read(istream &strm, const string &source) { 148 encode_tuples_.clear(); 149 encode_hash_.clear(); 150 int32 magic_number = 0; 151 ReadType(strm, &magic_number); 152 if (magic_number != kEncodeMagicNumber) { 153 LOG(ERROR) << "EncodeTable::Read: Bad encode table header: " << source; 154 return false; 155 } 156 ReadType(strm, &flags_); 157 int64 size; 158 ReadType(strm, &size); 159 if (!strm) { 160 LOG(ERROR) << "EncodeTable::Read: read failed: " << source; 161 return false; 162 } 163 for (size_t i = 0; i < size; ++i) { 164 Tuple* tuple = new Tuple(); 165 ReadType(strm, &tuple->ilabel); 166 ReadType(strm, &tuple->olabel); 167 tuple->weight.Read(strm); 168 encode_tuples_.push_back(tuple); 169 encode_hash_[encode_tuples_.back()] = encode_tuples_.size(); 170 } 171 if (!strm) { 172 LOG(ERROR) << "EncodeTable::Read: read failed: " << source; 173 return false; 174 } 175 return true; 176 } 177 178 uint32 flags() const { return flags_; } 179 private: 180 uint32 flags_; 181 vector<Tuple*> encode_tuples_; 182 EncodeHash encode_hash_; 183 184 DISALLOW_EVIL_CONSTRUCTORS(EncodeTable); 185 }; 186 187 188 // A mapper to encode/decode weighted transducers. Encoding of an 189 // Fst is useful for performing classical determinization or minimization 190 // on a weighted transducer by treating it as an unweighted acceptor over 191 // encoded labels. 192 // 193 // The Encode mapper stores the encoding in a local hash table (EncodeTable) 194 // This table is shared (and reference counted) between the encoder and 195 // decoder. A decoder has read only access to the EncodeTable. 196 // 197 // The EncodeMapper allows on the fly encoding of the machine. As the 198 // EncodeTable is generated the same table may by used to decode the machine 199 // on the fly. For example in the following sequence of operations 200 // 201 // Encode -> Determinize -> Decode 202 // 203 // we will use the encoding table generated during the encode step in the 204 // decode, even though the encoding is not complete. 205 // 206 template <class A> class EncodeMapper { 207 typedef typename A::Weight Weight; 208 typedef typename A::Label Label; 209 public: 210 EncodeMapper(uint32 flags, EncodeType type) 211 : ref_count_(1), flags_(flags), type_(type), 212 table_(new EncodeTable<A>(flags)) {} 213 214 EncodeMapper(const EncodeMapper& mapper) 215 : ref_count_(mapper.ref_count_ + 1), 216 flags_(mapper.flags_), 217 type_(mapper.type_), 218 table_(mapper.table_) { } 219 220 // Copy constructor but setting the type, typically to DECODE 221 EncodeMapper(const EncodeMapper& mapper, EncodeType type) 222 : ref_count_(mapper.ref_count_ + 1), 223 flags_(mapper.flags_), 224 type_(type), 225 table_(mapper.table_) { } 226 227 ~EncodeMapper() { 228 if (--ref_count_ == 0) delete table_; 229 } 230 231 A operator()(const A &arc) { 232 if (type_ == ENCODE) { // labels and/or weights to single label 233 if ((arc.nextstate == kNoStateId && !(flags_ & kEncodeWeights)) || 234 (arc.nextstate == kNoStateId && (flags_ & kEncodeWeights) && 235 arc.weight == Weight::Zero())) { 236 return arc; 237 } else { 238 Label label = table_->Encode(arc); 239 return A(label, 240 flags_ & kEncodeLabels ? label : arc.olabel, 241 flags_ & kEncodeWeights ? Weight::One() : arc.weight, 242 arc.nextstate); 243 } 244 } else { 245 if (arc.nextstate == kNoStateId) { 246 return arc; 247 } else { 248 const typename EncodeTable<A>::Tuple* tuple = 249 table_->Decode(arc.ilabel); 250 return A(tuple->ilabel, 251 flags_ & kEncodeLabels ? tuple->olabel : arc.olabel, 252 flags_ & kEncodeWeights ? tuple->weight : arc.weight, 253 arc.nextstate);; 254 } 255 } 256 } 257 258 uint64 Properties(uint64 props) { 259 uint64 mask = kFstProperties; 260 if (flags_ & kEncodeLabels) 261 mask &= kILabelInvariantProperties & kOLabelInvariantProperties; 262 if (flags_ & kEncodeWeights) 263 mask &= kILabelInvariantProperties & kWeightInvariantProperties & 264 (type_ == ENCODE ? kAddSuperFinalProperties : 265 kRmSuperFinalProperties); 266 return props & mask; 267 } 268 269 270 MapFinalAction FinalAction() const { 271 return (type_ == ENCODE && (flags_ & kEncodeWeights)) ? 272 MAP_REQUIRE_SUPERFINAL : MAP_NO_SUPERFINAL; 273 } 274 275 uint32 flags() const { return flags_; } 276 EncodeType type() const { return type_; } 277 278 bool Write(ostream &strm, const string& source) { 279 return table_->Write(strm, source); 280 } 281 282 bool Write(const string& filename) { 283 ofstream strm(filename.c_str()); 284 if (!strm) { 285 LOG(ERROR) << "EncodeMap: Can't open file: " << filename; 286 return false; 287 } 288 return Write(strm, filename); 289 } 290 291 static EncodeMapper<A> *Read(istream &strm, 292 const string& source, EncodeType type) { 293 EncodeTable<A> *table = new EncodeTable<A>(0); 294 bool r = table->Read(strm, source); 295 return r ? new EncodeMapper(table->flags(), type, table) : 0; 296 } 297 298 static EncodeMapper<A> *Read(const string& filename, EncodeType type) { 299 ifstream strm(filename.c_str()); 300 if (!strm) { 301 LOG(ERROR) << "EncodeMap: Can't open file: " << filename; 302 return false; 303 } 304 return Read(strm, filename, type); 305 } 306 307 private: 308 uint32 ref_count_; 309 uint32 flags_; 310 EncodeType type_; 311 EncodeTable<A>* table_; 312 313 explicit EncodeMapper(uint32 flags, EncodeType type, EncodeTable<A> *table) 314 : ref_count_(1), flags_(flags), type_(type), table_(table) {} 315 void operator=(const EncodeMapper &); // Disallow. 316 }; 317 318 319 // Complexity: O(nstates + narcs) 320 template<class A> inline 321 void Encode(MutableFst<A> *fst, EncodeMapper<A>* mapper) { 322 Map(fst, mapper); 323 } 324 325 326 template<class A> inline 327 void Decode(MutableFst<A>* fst, const EncodeMapper<A>& mapper) { 328 Map(fst, EncodeMapper<A>(mapper, DECODE)); 329 RmFinalEpsilon(fst); 330 } 331 332 333 // On the fly label and/or weight encoding of input Fst 334 // 335 // Complexity: 336 // - Constructor: O(1) 337 // - Traversal: O(nstates_visited + narcs_visited), assuming constant 338 // time to visit an input state or arc. 339 template <class A> 340 class EncodeFst : public MapFst<A, A, EncodeMapper<A> > { 341 public: 342 typedef A Arc; 343 typedef EncodeMapper<A> C; 344 345 EncodeFst(const Fst<A> &fst, EncodeMapper<A>* encoder) 346 : MapFst<A, A, C>(fst, encoder, MapFstOptions()) {} 347 348 EncodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder) 349 : MapFst<A, A, C>(fst, encoder, MapFstOptions()) {} 350 351 EncodeFst(const EncodeFst<A> &fst) 352 : MapFst<A, A, C>(fst) {} 353 354 virtual EncodeFst<A> *Copy() const { return new EncodeFst(*this); } 355 }; 356 357 358 // On the fly label and/or weight encoding of input Fst 359 // 360 // Complexity: 361 // - Constructor: O(1) 362 // - Traversal: O(nstates_visited + narcs_visited), assuming constant 363 // time to visit an input state or arc. 364 template <class A> 365 class DecodeFst : public MapFst<A, A, EncodeMapper<A> > { 366 public: 367 typedef A Arc; 368 typedef EncodeMapper<A> C; 369 370 DecodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder) 371 : MapFst<A, A, C>(fst, 372 EncodeMapper<A>(encoder, DECODE), 373 MapFstOptions()) {} 374 375 DecodeFst(const EncodeFst<A> &fst) 376 : MapFst<A, A, C>(fst) {} 377 378 virtual DecodeFst<A> *Copy() const { return new DecodeFst(*this); } 379 }; 380 381 382 // Specialization for EncodeFst. 383 template <class A> 384 class StateIterator< EncodeFst<A> > 385 : public StateIterator< MapFst<A, A, EncodeMapper<A> > > { 386 public: 387 explicit StateIterator(const EncodeFst<A> &fst) 388 : StateIterator< MapFst<A, A, EncodeMapper<A> > >(fst) {} 389 }; 390 391 392 // Specialization for EncodeFst. 393 template <class A> 394 class ArcIterator< EncodeFst<A> > 395 : public ArcIterator< MapFst<A, A, EncodeMapper<A> > > { 396 public: 397 ArcIterator(const EncodeFst<A> &fst, typename A::StateId s) 398 : ArcIterator< MapFst<A, A, EncodeMapper<A> > >(fst, s) {} 399 }; 400 401 402 // Specialization for DecodeFst. 403 template <class A> 404 class StateIterator< DecodeFst<A> > 405 : public StateIterator< MapFst<A, A, EncodeMapper<A> > > { 406 public: 407 explicit StateIterator(const DecodeFst<A> &fst) 408 : StateIterator< MapFst<A, A, EncodeMapper<A> > >(fst) {} 409 }; 410 411 412 // Specialization for DecodeFst. 413 template <class A> 414 class ArcIterator< DecodeFst<A> > 415 : public ArcIterator< MapFst<A, A, EncodeMapper<A> > > { 416 public: 417 ArcIterator(const DecodeFst<A> &fst, typename A::StateId s) 418 : ArcIterator< MapFst<A, A, EncodeMapper<A> > >(fst, s) {} 419 }; 420 421 422 // Useful aliases when using StdArc. 423 typedef EncodeFst<StdArc> StdEncodeFst; 424 425 typedef DecodeFst<StdArc> StdDecodeFst; 426 427 } 428 429 #endif // FST_LIB_ENCODE_H__ 430