1 // fst.h 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // 16 // \file 17 // Finite-State Transducer (FST) - abstract base class definition, 18 // state and arc iterator interface, and suggested base implementation. 19 20 #ifndef FST_LIB_FST_H__ 21 #define FST_LIB_FST_H__ 22 23 #include "fst/lib/arc.h" 24 #include "fst/lib/compat.h" 25 #include "fst/lib/properties.h" 26 #include "fst/lib/register.h" 27 #include "fst/lib/symbol-table.h" 28 #include "fst/lib/util.h" 29 30 namespace fst { 31 32 class FstHeader; 33 template <class A> class StateIteratorData; 34 template <class A> class ArcIteratorData; 35 36 struct FstReadOptions { 37 string source; // Where you're reading from 38 const FstHeader *header; // Pointer to Fst header (if non-zero) 39 const SymbolTable* isymbols; // Pointer to input symbols (if non-zero) 40 const SymbolTable* osymbols; // Pointer to output symbols (if non-zero) 41 42 explicit FstReadOptions(const string& src = "<unspecified>", 43 const FstHeader *hdr = 0, 44 const SymbolTable* isym = 0, 45 const SymbolTable* osym = 0) 46 : source(src), header(hdr), isymbols(isym), osymbols(osym) {} 47 }; 48 49 50 struct FstWriteOptions { 51 string source; // Where you're writing to 52 bool write_header; // Write the header? 53 bool write_isymbols; // Write input symbols? 54 bool write_osymbols; // Write output symbols? 55 56 explicit FstWriteOptions(const string& src = "<unspecifed>", 57 bool hdr = true, bool isym = true, 58 bool osym = true) 59 : source(src), write_header(hdr), 60 write_isymbols(isym), write_osymbols(osym) {} 61 }; 62 63 // 64 // Fst HEADER CLASS 65 // 66 // This is the recommended Fst file header representation. 67 // 68 69 class FstHeader { 70 public: 71 enum { 72 HAS_ISYMBOLS = 1, // Has input symbol table 73 HAS_OSYMBOLS = 2 // Has output symbol table 74 } Flags; 75 76 FstHeader() : version_(0), flags_(0), properties_(0), start_(-1), 77 numstates_(0), numarcs_(0) {} 78 const string &FstType() const { return fsttype_; } 79 const string &ArcType() const { return arctype_; } 80 int32 Version() const { return version_; } 81 int32 GetFlags() const { return flags_; } 82 uint64 Properties() const { return properties_; } 83 int64 Start() const { return start_; } 84 int64 NumStates() const { return numstates_; } 85 int64 NumArcs() const { return numarcs_; } 86 87 void SetFstType(const string& type) { fsttype_ = type; } 88 void SetArcType(const string& type) { arctype_ = type; } 89 void SetVersion(int32 version) { version_ = version; } 90 void SetFlags(int32 flags) { flags_ = flags; } 91 void SetProperties(uint64 properties) { properties_ = properties; } 92 void SetStart(int64 start) { start_ = start; } 93 void SetNumStates(int64 numstates) { numstates_ = numstates; } 94 void SetNumArcs(int64 numarcs) { numarcs_ = numarcs; } 95 96 bool Read(istream &strm, const string &source); 97 bool Write(ostream &strm, const string &source) const; 98 99 private: 100 string fsttype_; // E.g. "vector" 101 string arctype_; // E.g. "standard" 102 int32 version_; // Type version # 103 int32 flags_; // File format bits 104 uint64 properties_; // FST property bits 105 int64 start_; // Start state 106 int64 numstates_; // # of states 107 int64 numarcs_; // # of arcs 108 }; 109 110 // 111 // Fst INTERFACE CLASS DEFINITION 112 // 113 114 // A generic FST, templated on the arc definition, with 115 // common-demoninator methods (use StateIterator and ArcIterator to 116 // iterate over its states and arcs). 117 template <class A> 118 class Fst { 119 public: 120 typedef A Arc; 121 typedef typename A::Weight Weight; 122 typedef typename A::StateId StateId; 123 124 virtual ~Fst() {} 125 126 virtual StateId Start() const = 0; // Initial state 127 128 virtual Weight Final(StateId) const = 0; // State's final weight 129 130 virtual size_t NumArcs(StateId) const = 0; // State's arc count 131 132 virtual size_t NumInputEpsilons(StateId) 133 const = 0; // State's input epsilon count 134 135 virtual size_t NumOutputEpsilons(StateId) 136 const = 0; // State's output epsilon count 137 138 // If test=false, return stored properties bits for mask (some poss. unknown) 139 // If test=true, return property bits for mask (computing o.w. unknown) 140 virtual uint64 Properties(uint64 mask, bool test) 141 const = 0; // Property bits 142 143 virtual const string& Type() const = 0; // Fst type name 144 145 // Get a copy of this Fst. 146 virtual Fst<A> *Copy() const = 0; 147 // Read an Fst from an input stream; returns NULL on error 148 149 static Fst<A> *Read(istream &strm, const FstReadOptions &opts) { 150 FstReadOptions ropts(opts); 151 FstHeader hdr; 152 if (ropts.header) 153 hdr = *opts.header; 154 else { 155 if (!hdr.Read(strm, opts.source)) 156 return 0; 157 ropts.header = &hdr; 158 } 159 FstRegister<A> *registr = FstRegister<A>::GetRegister(); 160 const typename FstRegister<A>::Reader reader = 161 registr->GetReader(hdr.FstType()); 162 if (!reader) { 163 LOG(ERROR) << "Fst::Read: Unknown FST type \"" << hdr.FstType() 164 << "\" (arc type = \"" << A::Type() 165 << "\"): " << ropts.source; 166 return 0; 167 } 168 return reader(strm, ropts); 169 }; 170 171 // Read an Fst from a file; return NULL on error 172 static Fst<A> *Read(const string &filename) { 173 ifstream strm(filename.c_str()); 174 if (!strm) { 175 LOG(ERROR) << "Fst::Read: Can't open file: " << filename; 176 return 0; 177 } 178 return Read(strm, FstReadOptions(filename)); 179 } 180 181 // Write an Fst to an output stream; return false on error 182 virtual bool Write(ostream &strm, const FstWriteOptions &opts) const { 183 LOG(ERROR) << "Fst::Write: No write method for " << Type() << " Fst type"; 184 return false; 185 } 186 187 // Write an Fst to a file; return false on error 188 virtual bool Write(const string &filename) const { 189 LOG(ERROR) << "Fst::Write: No write method for " 190 << Type() << " Fst type: " 191 << (filename.empty() ? "standard output" : filename); 192 return false; 193 } 194 195 // Return input label symbol table; return NULL if not specified 196 virtual const SymbolTable* InputSymbols() const = 0; 197 198 // Return output label symbol table; return NULL if not specified 199 virtual const SymbolTable* OutputSymbols() const = 0; 200 201 // For generic state iterator construction; not normally called 202 // directly by users. 203 virtual void InitStateIterator(StateIteratorData<A> *) const = 0; 204 205 // For generic arc iterator construction; not normally called 206 // directly by users. 207 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *) const = 0; 208 }; 209 210 211 // 212 // STATE and ARC ITERATOR DEFINITIONS 213 // 214 215 // State iterator interface templated on the Arc definition; used 216 // for StateIterator specializations returned by InitStateIterator. 217 template <class A> 218 class StateIteratorBase { 219 public: 220 typedef A Arc; 221 typedef typename A::StateId StateId; 222 223 virtual ~StateIteratorBase() {} 224 virtual bool Done() const = 0; // End of iterator? 225 virtual StateId Value() const = 0; // Current state (when !Done) 226 virtual void Next() = 0; // Advance to next state (when !Done) 227 virtual void Reset() = 0; // Return to initial condition 228 }; 229 230 231 // StateIterator initialization data 232 template <class A> struct StateIteratorData { 233 StateIteratorBase<A> *base; // Specialized iterator if non-zero 234 typename A::StateId nstates; // O.w. total # of states 235 }; 236 237 238 // Generic state iterator, templated on the FST definition 239 // - a wrapper around pointer to specific one. 240 // Here is a typical use: \code 241 // for (StateIterator<StdFst> siter(fst); 242 // !siter.Done(); 243 // siter.Next()) { 244 // StateId s = siter.Value(); 245 // ... 246 // } \endcode 247 template <class F> 248 class StateIterator { 249 public: 250 typedef typename F::Arc Arc; 251 typedef typename Arc::StateId StateId; 252 253 explicit StateIterator(const F &fst) : s_(0) { 254 fst.InitStateIterator(&data_); 255 } 256 257 ~StateIterator() { if (data_.base) delete data_.base; } 258 259 bool Done() const { 260 return data_.base ? data_.base->Done() : s_ >= data_.nstates; 261 } 262 263 StateId Value() const { return data_.base ? data_.base->Value() : s_; } 264 265 void Next() { 266 if (data_.base) 267 data_.base->Next(); 268 else 269 ++s_; 270 } 271 272 void Reset() { 273 if (data_.base) 274 data_.base->Reset(); 275 else 276 s_ = 0; 277 } 278 279 private: 280 StateIteratorData<Arc> data_; 281 StateId s_; 282 DISALLOW_EVIL_CONSTRUCTORS(StateIterator); 283 }; 284 285 286 // Arc iterator interface, templated on the Arc definition; used 287 // for Arc iterator specializations that are returned by InitArcIterator. 288 template <class A> 289 class ArcIteratorBase { 290 public: 291 typedef A Arc; 292 typedef typename A::StateId StateId; 293 294 virtual ~ArcIteratorBase() {} 295 virtual bool Done() const = 0; // End of iterator? 296 virtual const A& Value() const = 0; // Current state (when !Done) 297 virtual void Next() = 0; // Advance to next arc (when !Done) 298 virtual void Reset() = 0; // Return to initial condition 299 virtual void Seek(size_t a) = 0; // Random arc access by position 300 }; 301 302 303 // ArcIterator initialization data 304 template <class A> struct ArcIteratorData { 305 ArcIteratorBase<A> *base; // Specialized iterator if non-zero 306 const A *arcs; // O.w. arcs pointer 307 size_t narcs; // ... and arc count 308 int *ref_count; // ... and reference count if non-zero 309 }; 310 311 312 // Generic arc iterator, templated on the FST definition 313 // - a wrapper around pointer to specific one. 314 // Here is a typical use: \code 315 // for (ArcIterator<StdFst> aiter(fst, s)); 316 // !aiter.Done(); 317 // aiter.Next()) { 318 // StdArc &arc = aiter.Value(); 319 // ... 320 // } \endcode 321 template <class F> 322 class ArcIterator { 323 public: 324 typedef typename F::Arc Arc; 325 typedef typename Arc::StateId StateId; 326 327 ArcIterator(const F &fst, StateId s) : i_(0) { 328 fst.InitArcIterator(s, &data_); 329 } 330 331 ~ArcIterator() { 332 if (data_.base) 333 delete data_.base; 334 else if (data_.ref_count) 335 --(*data_.ref_count); 336 } 337 338 bool Done() const { 339 return data_.base ? data_.base->Done() : i_ >= data_.narcs; 340 } 341 342 const Arc& Value() const { 343 return data_.base ? data_.base->Value() : data_.arcs[i_]; 344 } 345 346 void Next() { 347 if (data_.base) 348 data_.base->Next(); 349 else 350 ++i_; 351 } 352 353 void Reset() { 354 if (data_.base) 355 data_.base->Reset(); 356 else 357 i_ = 0; 358 } 359 360 void Seek(size_t a) { 361 if (data_.base) 362 data_.base->Seek(a); 363 else 364 i_ = a; 365 } 366 367 private: 368 ArcIteratorData<Arc> data_; 369 size_t i_; 370 DISALLOW_EVIL_CONSTRUCTORS(ArcIterator); 371 }; 372 373 374 // A useful alias when using StdArc. 375 typedef Fst<StdArc> StdFst; 376 377 378 // 379 // CONSTANT DEFINITIONS 380 // 381 382 const int kNoStateId = -1; // Not a valid state ID 383 const int kNoLabel = -1; // Not a valid label 384 const int kPhiLabel = -2; // Failure transition label 385 const int kRhoLabel = -3; // Matches o.w. unmatched labels (lib. internal) 386 const int kSigmaLabel = -4; // Matches all labels in alphabet. 387 388 389 // 390 // Fst IMPLEMENTATION BASE 391 // 392 // This is the recommended Fst implementation base class. It will 393 // handle reference counts, property bits, type information and symbols. 394 // 395 396 template <class A> class FstImpl { 397 public: 398 typedef typename A::Weight Weight; 399 typedef typename A::StateId StateId; 400 401 FstImpl() 402 : properties_(0), type_("null"), isymbols_(0), osymbols_(0), 403 ref_count_(1) {} 404 405 FstImpl(const FstImpl<A> &impl) 406 : properties_(impl.properties_), type_(impl.type_), 407 isymbols_(impl.isymbols_ ? new SymbolTable(impl.isymbols_) : 0), 408 osymbols_(impl.osymbols_ ? new SymbolTable(impl.osymbols_) : 0), 409 ref_count_(1) {} 410 411 ~FstImpl() { 412 delete isymbols_; 413 delete osymbols_; 414 } 415 416 const string& Type() const { return type_; } 417 418 void SetType(const string &type) { type_ = type; } 419 420 uint64 Properties() const { return properties_; } 421 422 uint64 Properties(uint64 mask) const { return properties_ & mask; } 423 424 void SetProperties(uint64 props) { properties_ = props; } 425 426 void SetProperties(uint64 props, uint64 mask) { 427 properties_ &= ~mask; 428 properties_ |= props & mask; 429 } 430 431 const SymbolTable* InputSymbols() const { return isymbols_; } 432 433 const SymbolTable* OutputSymbols() const { return osymbols_; } 434 435 SymbolTable* InputSymbols() { return isymbols_; } 436 437 SymbolTable* OutputSymbols() { return osymbols_; } 438 439 void SetInputSymbols(const SymbolTable* isyms) { 440 if (isymbols_) delete isymbols_; 441 isymbols_ = isyms ? isyms->Copy() : 0; 442 } 443 444 void SetOutputSymbols(const SymbolTable* osyms) { 445 if (osymbols_) delete osymbols_; 446 osymbols_ = osyms ? osyms->Copy() : 0; 447 } 448 449 int RefCount() const { return ref_count_; } 450 451 int IncrRefCount() { return ++ref_count_; } 452 453 int DecrRefCount() { return --ref_count_; } 454 455 // Read-in header and symbols, initialize Fst, and return the header. 456 // If opts.header is non-null, skip read-in and use the option value. 457 // If opts.[io]symbols is non-null, read-in but use the option value. 458 bool ReadHeaderAndSymbols(istream &strm, const FstReadOptions& opts, 459 int min_version, FstHeader *hdr) { 460 if (opts.header) 461 *hdr = *opts.header; 462 else if (!hdr->Read(strm, opts.source)) 463 return false; 464 if (hdr->FstType() != type_) { 465 LOG(ERROR) << "FstImpl::ReadHeaderAndSymbols: Fst not of type \"" 466 << type_ << "\": " << opts.source; 467 return false; 468 } 469 if (hdr->ArcType() != A::Type()) { 470 LOG(ERROR) << "FstImpl::ReadHeaderAndSymbols: Arc not of type \"" 471 << A::Type() 472 << "\": " << opts.source; 473 return false; 474 } 475 if (hdr->Version() < min_version) { 476 LOG(ERROR) << "FstImpl::ReadHeaderAndSymbols: Obsolete " 477 << type_ << " Fst version: " << opts.source; 478 return false; 479 } 480 properties_ = hdr->Properties(); 481 if (hdr->GetFlags() & FstHeader::HAS_ISYMBOLS) 482 isymbols_ = SymbolTable::Read(strm, opts.source); 483 if (hdr->GetFlags() & FstHeader::HAS_OSYMBOLS) 484 osymbols_ =SymbolTable::Read(strm, opts.source); 485 486 if (opts.isymbols) { 487 delete isymbols_; 488 isymbols_ = opts.isymbols->Copy(); 489 } 490 if (opts.osymbols) { 491 delete osymbols_; 492 osymbols_ = opts.osymbols->Copy(); 493 } 494 return true; 495 } 496 497 // Write-out header and symbols. 498 // If a opts.header is false, skip writing header. 499 // If opts.[io]symbols is false, skip writing those symbols. 500 void WriteHeaderAndSymbols(ostream &strm, const FstWriteOptions& opts, 501 int version, FstHeader *hdr) const { 502 if (opts.write_header) { 503 hdr->SetFstType(type_); 504 hdr->SetArcType(A::Type()); 505 hdr->SetVersion(version); 506 hdr->SetProperties(properties_); 507 int32 file_flags = 0; 508 if (isymbols_ && opts.write_isymbols) 509 file_flags |= FstHeader::HAS_ISYMBOLS; 510 if (osymbols_ && opts.write_osymbols) 511 file_flags |= FstHeader::HAS_OSYMBOLS; 512 hdr->SetFlags(file_flags); 513 hdr->Write(strm, opts.source); 514 } 515 if (isymbols_ && opts.write_isymbols) isymbols_->Write(strm); 516 if (osymbols_ && opts.write_osymbols) osymbols_->Write(strm); 517 } 518 519 protected: 520 uint64 properties_; // Property bits 521 522 private: 523 string type_; // Unique name of Fst class 524 SymbolTable *isymbols_; // Ilabel symbol table 525 SymbolTable *osymbols_; // Olabel symbol table 526 int ref_count_; // Reference count 527 528 void operator=(const FstImpl<A> &impl); // disallow 529 }; 530 531 } // namespace fst; 532 533 #endif // FST_LIB_FST_H__ 534