1 // encode.h 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // 16 // \file 17 // Class to encode and decoder an fst. 18 19 #ifndef FST_LIB_ENCODE_H__ 20 #define FST_LIB_ENCODE_H__ 21 22 #include "fst/lib/map.h" 23 #include "fst/lib/rmfinalepsilon.h" 24 25 namespace fst { 26 27 static const uint32 kEncodeLabels = 0x00001; 28 static const uint32 kEncodeWeights = 0x00002; 29 30 enum EncodeType { ENCODE = 1, DECODE = 2 }; 31 32 // Identifies stream data as an encode table (and its endianity) 33 static const int32 kEncodeMagicNumber = 2129983209; 34 35 36 // The following class encapsulates implementation details for the 37 // encoding and decoding of label/weight tuples used for encoding 38 // and decoding of Fsts. The EncodeTable is bidirectional. I.E it 39 // stores both the Tuple of encode labels and weights to a unique 40 // label, and the reverse. 41 template <class A> class EncodeTable { 42 public: 43 typedef typename A::Label Label; 44 typedef typename A::Weight Weight; 45 46 // Encoded data consists of arc input/output labels and arc weight 47 struct Tuple { 48 Tuple() {} 49 Tuple(Label ilabel_, Label olabel_, Weight weight_) 50 : ilabel(ilabel_), olabel(olabel_), weight(weight_) {} 51 Tuple(const Tuple& tuple) 52 : ilabel(tuple.ilabel), olabel(tuple.olabel), weight(tuple.weight) {} 53 54 Label ilabel; 55 Label olabel; 56 Weight weight; 57 }; 58 59 // Comparison object for hashing EncodeTable Tuple(s). 60 class TupleEqual { 61 public: 62 bool operator()(const Tuple* x, const Tuple* y) const { 63 return (x->ilabel == y->ilabel && 64 x->olabel == y->olabel && 65 x->weight == y->weight); 66 } 67 }; 68 69 // Hash function for EncodeTabe Tuples. Based on the encode flags 70 // we either hash the labels, weights or compbination of them. 71 class TupleKey { 72 static const int kPrime = 7853; 73 public: 74 TupleKey() 75 : encode_flags_(kEncodeLabels | kEncodeWeights) {} 76 77 TupleKey(const TupleKey& key) 78 : encode_flags_(key.encode_flags_) {} 79 80 explicit TupleKey(uint32 encode_flags) 81 : encode_flags_(encode_flags) {} 82 83 size_t operator()(const Tuple* x) const { 84 int lshift = x->ilabel % kPrime; 85 int rshift = sizeof(size_t) - lshift; 86 size_t hash = x->ilabel << lshift; 87 if (encode_flags_ & kEncodeLabels) hash ^= x->olabel >> rshift; 88 if (encode_flags_ & kEncodeWeights) hash ^= x->weight.Hash(); 89 return hash; 90 } 91 92 private: 93 int32 encode_flags_; 94 }; 95 96 typedef hash_map<const Tuple*, 97 Label, 98 TupleKey, 99 TupleEqual> EncodeHash; 100 101 explicit EncodeTable(uint32 encode_flags) 102 : flags_(encode_flags), 103 encode_hash_(1024, TupleKey(encode_flags)) {} 104 105 ~EncodeTable() { 106 for (size_t i = 0; i < encode_tuples_.size(); ++i) { 107 delete encode_tuples_[i]; 108 } 109 } 110 111 // Given an arc encode either input/ouptut labels or input/costs or both 112 Label Encode(const A &arc) { 113 const Tuple tuple(arc.ilabel, 114 flags_ & kEncodeLabels ? arc.olabel : 0, 115 flags_ & kEncodeWeights ? arc.weight : Weight::One()); 116 typename EncodeHash::const_iterator it = encode_hash_.find(&tuple); 117 if (it == encode_hash_.end()) { 118 encode_tuples_.push_back(new Tuple(tuple)); 119 encode_hash_[encode_tuples_.back()] = encode_tuples_.size(); 120 return encode_tuples_.size(); 121 } else { 122 return it->second; 123 } 124 } 125 126 // Given an encode arc Label decode back to input/output labels and costs 127 const Tuple* Decode(Label key) { 128 return key <= (Label)encode_tuples_.size() ? encode_tuples_[key - 1] : 0; 129 } 130 131 bool Write(ostream &strm, const string &source) const { 132 WriteType(strm, kEncodeMagicNumber); 133 WriteType(strm, flags_); 134 int64 size = encode_tuples_.size(); 135 WriteType(strm, size); 136 for (size_t i = 0; i < size; ++i) { 137 const Tuple* tuple = encode_tuples_[i]; 138 WriteType(strm, tuple->ilabel); 139 WriteType(strm, tuple->olabel); 140 tuple->weight.Write(strm); 141 } 142 strm.flush(); 143 if (!strm) 144 LOG(ERROR) << "EncodeTable::Write: write failed: " << source; 145 return strm; 146 } 147 148 bool Read(istream &strm, const string &source) { 149 encode_tuples_.clear(); 150 encode_hash_.clear(); 151 int32 magic_number = 0; 152 ReadType(strm, &magic_number); 153 if (magic_number != kEncodeMagicNumber) { 154 LOG(ERROR) << "EncodeTable::Read: Bad encode table header: " << source; 155 return false; 156 } 157 ReadType(strm, &flags_); 158 int64 size; 159 ReadType(strm, &size); 160 if (!strm) { 161 LOG(ERROR) << "EncodeTable::Read: read failed: " << source; 162 return false; 163 } 164 for (size_t i = 0; i < size; ++i) { 165 Tuple* tuple = new Tuple(); 166 ReadType(strm, &tuple->ilabel); 167 ReadType(strm, &tuple->olabel); 168 tuple->weight.Read(strm); 169 encode_tuples_.push_back(tuple); 170 encode_hash_[encode_tuples_.back()] = encode_tuples_.size(); 171 } 172 if (!strm) 173 LOG(ERROR) << "EncodeTable::Read: read failed: " << source; 174 return strm; 175 } 176 177 const uint32 flags() const { return flags_; } 178 private: 179 uint32 flags_; 180 vector<Tuple*> encode_tuples_; 181 EncodeHash encode_hash_; 182 183 DISALLOW_EVIL_CONSTRUCTORS(EncodeTable); 184 }; 185 186 187 // A mapper to encode/decode weighted transducers. Encoding of an 188 // Fst is useful for performing classical determinization or minimization 189 // on a weighted transducer by treating it as an unweighted acceptor over 190 // encoded labels. 191 // 192 // The Encode mapper stores the encoding in a local hash table (EncodeTable) 193 // This table is shared (and reference counted) between the encoder and 194 // decoder. A decoder has read only access to the EncodeTable. 195 // 196 // The EncodeMapper allows on the fly encoding of the machine. As the 197 // EncodeTable is generated the same table may by used to decode the machine 198 // on the fly. For example in the following sequence of operations 199 // 200 // Encode -> Determinize -> Decode 201 // 202 // we will use the encoding table generated during the encode step in the 203 // decode, even though the encoding is not complete. 204 // 205 template <class A> class EncodeMapper { 206 typedef typename A::Weight Weight; 207 typedef typename A::Label Label; 208 public: 209 EncodeMapper(uint32 flags, EncodeType type) 210 : ref_count_(1), flags_(flags), type_(type), 211 table_(new EncodeTable<A>(flags)) {} 212 213 EncodeMapper(const EncodeMapper& mapper) 214 : ref_count_(mapper.ref_count_ + 1), 215 flags_(mapper.flags_), 216 type_(mapper.type_), 217 table_(mapper.table_) { } 218 219 // Copy constructor but setting the type, typically to DECODE 220 EncodeMapper(const EncodeMapper& mapper, EncodeType type) 221 : ref_count_(mapper.ref_count_ + 1), 222 flags_(mapper.flags_), 223 type_(type), 224 table_(mapper.table_) { } 225 226 ~EncodeMapper() { 227 if (--ref_count_ == 0) delete table_; 228 } 229 230 A operator()(const A &arc) { 231 if (type_ == ENCODE) { // labels and/or weights to single label 232 if ((arc.nextstate == kNoStateId && !(flags_ & kEncodeWeights)) || 233 (arc.nextstate == kNoStateId && (flags_ & kEncodeWeights) && 234 arc.weight == Weight::Zero())) { 235 return arc; 236 } else { 237 Label label = table_->Encode(arc); 238 return A(label, 239 flags_ & kEncodeLabels ? label : arc.olabel, 240 flags_ & kEncodeWeights ? Weight::One() : arc.weight, 241 arc.nextstate); 242 } 243 } else { 244 if (arc.nextstate == kNoStateId) { 245 return arc; 246 } else { 247 const typename EncodeTable<A>::Tuple* tuple = 248 table_->Decode(arc.ilabel); 249 return A(tuple->ilabel, 250 flags_ & kEncodeLabels ? tuple->olabel : arc.olabel, 251 flags_ & kEncodeWeights ? tuple->weight : arc.weight, 252 arc.nextstate);; 253 } 254 } 255 } 256 257 uint64 Properties(uint64 props) { 258 uint64 mask = kFstProperties; 259 if (flags_ & kEncodeLabels) 260 mask &= kILabelInvariantProperties & kOLabelInvariantProperties; 261 if (flags_ & kEncodeWeights) 262 mask &= kILabelInvariantProperties & kWeightInvariantProperties & 263 (type_ == ENCODE ? kAddSuperFinalProperties : 264 kRmSuperFinalProperties); 265 return props & mask; 266 } 267 268 269 MapFinalAction FinalAction() const { 270 return (type_ == ENCODE && (flags_ & kEncodeWeights)) ? 271 MAP_REQUIRE_SUPERFINAL : MAP_NO_SUPERFINAL; 272 } 273 274 const uint32 flags() const { return flags_; } 275 const EncodeType type() const { return type_; } 276 277 bool Write(ostream &strm, const string& source) { 278 return table_->Write(strm, source); 279 } 280 281 bool Write(const string& filename) { 282 ofstream strm(filename.c_str()); 283 if (!strm) { 284 LOG(ERROR) << "EncodeMap: Can't open file: " << filename; 285 return false; 286 } 287 return Write(strm, filename); 288 } 289 290 static EncodeMapper<A> *Read(istream &strm, 291 const string& source, EncodeType type) { 292 EncodeTable<A> *table = new EncodeTable<A>(0); 293 bool r = table->Read(strm, source); 294 return r ? new EncodeMapper(table->flags(), type, table) : 0; 295 } 296 297 static EncodeMapper<A> *Read(const string& filename, EncodeType type) { 298 ifstream strm(filename.c_str()); 299 if (!strm) { 300 LOG(ERROR) << "EncodeMap: Can't open file: " << filename; 301 return false; 302 } 303 return Read(strm, filename, type); 304 } 305 306 private: 307 uint32 ref_count_; 308 uint32 flags_; 309 EncodeType type_; 310 EncodeTable<A>* table_; 311 312 explicit EncodeMapper(uint32 flags, EncodeType type, EncodeTable<A> *table) 313 : ref_count_(1), flags_(flags), type_(type), table_(table) {} 314 void operator=(const EncodeMapper &); // Disallow. 315 }; 316 317 318 // Complexity: O(nstates + narcs) 319 template<class A> inline 320 void Encode(MutableFst<A> *fst, EncodeMapper<A>* mapper) { 321 Map(fst, mapper); 322 } 323 324 325 template<class A> inline 326 void Decode(MutableFst<A>* fst, const EncodeMapper<A>& mapper) { 327 Map(fst, EncodeMapper<A>(mapper, DECODE)); 328 RmFinalEpsilon(fst); 329 } 330 331 332 // On the fly label and/or weight encoding of input Fst 333 // 334 // Complexity: 335 // - Constructor: O(1) 336 // - Traversal: O(nstates_visited + narcs_visited), assuming constant 337 // time to visit an input state or arc. 338 template <class A> 339 class EncodeFst : public MapFst<A, A, EncodeMapper<A> > { 340 public: 341 typedef A Arc; 342 typedef EncodeMapper<A> C; 343 344 EncodeFst(const Fst<A> &fst, EncodeMapper<A>* encoder) 345 : MapFst<A, A, C>(fst, encoder, MapFstOptions()) {} 346 347 EncodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder) 348 : MapFst<A, A, C>(fst, encoder, MapFstOptions()) {} 349 350 EncodeFst(const EncodeFst<A> &fst) 351 : MapFst<A, A, C>(fst) {} 352 353 virtual EncodeFst<A> *Copy() const { return new EncodeFst(*this); } 354 }; 355 356 357 // On the fly label and/or weight encoding of input Fst 358 // 359 // Complexity: 360 // - Constructor: O(1) 361 // - Traversal: O(nstates_visited + narcs_visited), assuming constant 362 // time to visit an input state or arc. 363 template <class A> 364 class DecodeFst : public MapFst<A, A, EncodeMapper<A> > { 365 public: 366 typedef A Arc; 367 typedef EncodeMapper<A> C; 368 369 DecodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder) 370 : MapFst<A, A, C>(fst, 371 EncodeMapper<A>(encoder, DECODE), 372 MapFstOptions()) {} 373 374 DecodeFst(const EncodeFst<A> &fst) 375 : MapFst<A, A, C>(fst) {} 376 377 virtual DecodeFst<A> *Copy() const { return new DecodeFst(*this); } 378 }; 379 380 381 // Specialization for EncodeFst. 382 template <class A> 383 class StateIterator< EncodeFst<A> > 384 : public StateIterator< MapFst<A, A, EncodeMapper<A> > > { 385 public: 386 explicit StateIterator(const EncodeFst<A> &fst) 387 : StateIterator< MapFst<A, A, EncodeMapper<A> > >(fst) {} 388 }; 389 390 391 // Specialization for EncodeFst. 392 template <class A> 393 class ArcIterator< EncodeFst<A> > 394 : public ArcIterator< MapFst<A, A, EncodeMapper<A> > > { 395 public: 396 ArcIterator(const EncodeFst<A> &fst, typename A::StateId s) 397 : ArcIterator< MapFst<A, A, EncodeMapper<A> > >(fst, s) {} 398 }; 399 400 401 // Specialization for DecodeFst. 402 template <class A> 403 class StateIterator< DecodeFst<A> > 404 : public StateIterator< MapFst<A, A, EncodeMapper<A> > > { 405 public: 406 explicit StateIterator(const DecodeFst<A> &fst) 407 : StateIterator< MapFst<A, A, EncodeMapper<A> > >(fst) {} 408 }; 409 410 411 // Specialization for DecodeFst. 412 template <class A> 413 class ArcIterator< DecodeFst<A> > 414 : public ArcIterator< MapFst<A, A, EncodeMapper<A> > > { 415 public: 416 ArcIterator(const DecodeFst<A> &fst, typename A::StateId s) 417 : ArcIterator< MapFst<A, A, EncodeMapper<A> > >(fst, s) {} 418 }; 419 420 421 // Useful aliases when using StdArc. 422 typedef EncodeFst<StdArc> StdEncodeFst; 423 424 typedef DecodeFst<StdArc> StdDecodeFst; 425 426 } 427 428 #endif // FST_LIB_ENCODE_H__ 429