Home | History | Annotate | Download | only in lib
      1 // fst.h
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 //
     16 // \file
     17 // Finite-State Transducer (FST) - abstract base class definition,
     18 // state and arc iterator interface, and suggested base implementation.
     19 
     20 #ifndef FST_LIB_FST_H__
     21 #define FST_LIB_FST_H__
     22 
     23 #include "fst/lib/arc.h"
     24 #include "fst/lib/compat.h"
     25 #include "fst/lib/properties.h"
     26 #include "fst/lib/register.h"
     27 #include "fst/lib/symbol-table.h"
     28 #include "fst/lib/util.h"
     29 
     30 namespace fst {
     31 
     32 class FstHeader;
     33 template <class A> class StateIteratorData;
     34 template <class A> class ArcIteratorData;
     35 
     36 struct FstReadOptions  {
     37   string source;                // Where you're reading from
     38   const FstHeader *header;      // Pointer to Fst header (if non-zero)
     39   const SymbolTable* isymbols;  // Pointer to input symbols (if non-zero)
     40   const SymbolTable* osymbols;  // Pointer to output symbols (if non-zero)
     41 
     42   explicit FstReadOptions(const string& src = "<unspecified>",
     43                           const FstHeader *hdr = 0,
     44                           const SymbolTable* isym = 0,
     45                           const SymbolTable* osym = 0)
     46       : source(src), header(hdr), isymbols(isym), osymbols(osym) {}
     47 };
     48 
     49 
     50 struct FstWriteOptions {
     51   string source;                    // Where you're writing to
     52   bool write_header;                // Write the header?
     53   bool write_isymbols;              // Write input symbols?
     54   bool write_osymbols;              // Write output symbols?
     55 
     56   explicit FstWriteOptions(const string& src = "<unspecifed>",
     57                            bool hdr = true, bool isym = true,
     58                            bool osym = true)
     59       : source(src), write_header(hdr),
     60         write_isymbols(isym),  write_osymbols(osym) {}
     61 };
     62 
     63 //
     64 // Fst HEADER CLASS
     65 //
     66 // This is the recommended Fst file header representation.
     67 //
     68 
     69 class FstHeader {
     70  public:
     71   enum {
     72     HAS_ISYMBOLS = 1,                           // Has input symbol table
     73     HAS_OSYMBOLS = 2                            // Has output symbol table
     74   } Flags;
     75 
     76   FstHeader() : version_(0), flags_(0), properties_(0), start_(-1),
     77                 numstates_(0), numarcs_(0) {}
     78   const string &FstType() const { return fsttype_; }
     79   const string &ArcType() const { return arctype_; }
     80   int32 Version() const { return version_; }
     81   int32 GetFlags() const { return flags_; }
     82   uint64 Properties() const { return properties_; }
     83   int64 Start() const { return start_; }
     84   int64 NumStates() const { return numstates_; }
     85   int64 NumArcs() const { return numarcs_; }
     86 
     87   void SetFstType(const string& type) { fsttype_ = type; }
     88   void SetArcType(const string& type) { arctype_ = type; }
     89   void SetVersion(int32 version) { version_ = version; }
     90   void SetFlags(int32 flags) { flags_ = flags; }
     91   void SetProperties(uint64 properties) { properties_ = properties; }
     92   void SetStart(int64 start) { start_ = start; }
     93   void SetNumStates(int64 numstates) { numstates_ = numstates; }
     94   void SetNumArcs(int64 numarcs) { numarcs_ = numarcs; }
     95 
     96   bool Read(istream &strm, const string &source);
     97   bool Write(ostream &strm, const string &source) const;
     98 
     99  private:
    100   string fsttype_;                   // E.g. "vector"
    101   string arctype_;                   // E.g. "standard"
    102   int32 version_;                    // Type version #
    103   int32 flags_;                      // File format bits
    104   uint64 properties_;                // FST property bits
    105   int64 start_;                      // Start state
    106   int64 numstates_;                  // # of states
    107   int64 numarcs_;                    // # of arcs
    108 };
    109 
    110 //
    111 // Fst INTERFACE CLASS DEFINITION
    112 //
    113 
    114 // A generic FST, templated on the arc definition, with
    115 // common-demoninator methods (use StateIterator and ArcIterator to
    116 // iterate over its states and arcs).
    117 template <class A>
    118 class Fst {
    119  public:
    120   typedef A Arc;
    121   typedef typename A::Weight Weight;
    122   typedef typename A::StateId StateId;
    123 
    124   virtual ~Fst() {}
    125 
    126   virtual StateId Start() const = 0;          // Initial state
    127 
    128   virtual Weight Final(StateId) const = 0;    // State's final weight
    129 
    130   virtual size_t NumArcs(StateId) const = 0;  // State's arc count
    131 
    132   virtual size_t NumInputEpsilons(StateId)
    133       const = 0;                              // State's input epsilon count
    134 
    135   virtual size_t NumOutputEpsilons(StateId)
    136       const = 0;                              // State's output epsilon count
    137 
    138   // If test=false, return stored properties bits for mask (some poss. unknown)
    139   // If test=true, return property bits for mask (computing o.w. unknown)
    140   virtual uint64 Properties(uint64 mask, bool test)
    141       const = 0;  // Property bits
    142 
    143   virtual const string& Type() const = 0;    // Fst type name
    144 
    145   // Get a copy of this Fst.
    146   virtual Fst<A> *Copy() const = 0;
    147   // Read an Fst from an input stream; returns NULL on error
    148 
    149   static Fst<A> *Read(istream &strm, const FstReadOptions &opts) {
    150     FstReadOptions ropts(opts);
    151     FstHeader hdr;
    152     if (ropts.header)
    153       hdr = *opts.header;
    154     else {
    155       if (!hdr.Read(strm, opts.source))
    156         return 0;
    157       ropts.header = &hdr;
    158     }
    159     FstRegister<A> *registr = FstRegister<A>::GetRegister();
    160     const typename FstRegister<A>::Reader reader =
    161         registr->GetReader(hdr.FstType());
    162     if (!reader) {
    163       LOG(ERROR) << "Fst::Read: Unknown FST type \"" << hdr.FstType()
    164                  << "\" (arc type = \"" << A::Type()
    165                  << "\"): " << ropts.source;
    166       return 0;
    167     }
    168     return reader(strm, ropts);
    169   };
    170 
    171   // Read an Fst from a file; return NULL on error
    172   static Fst<A> *Read(const string &filename) {
    173     ifstream strm(filename.c_str());
    174     if (!strm) {
    175       LOG(ERROR) << "Fst::Read: Can't open file: " << filename;
    176       return 0;
    177     }
    178     return Read(strm, FstReadOptions(filename));
    179   }
    180 
    181   // Write an Fst to an output stream; return false on error
    182   virtual bool Write(ostream &strm, const FstWriteOptions &opts) const {
    183     LOG(ERROR) << "Fst::Write: No write method for " << Type() << " Fst type";
    184     return false;
    185   }
    186 
    187   // Write an Fst to a file; return false on error
    188   virtual bool Write(const string &filename) const {
    189     LOG(ERROR) << "Fst::Write: No write method for "
    190                << Type() << " Fst type: "
    191                << (filename.empty() ? "standard output" : filename);
    192     return false;
    193   }
    194 
    195   // Return input label symbol table; return NULL if not specified
    196   virtual const SymbolTable* InputSymbols() const = 0;
    197 
    198   // Return output label symbol table; return NULL if not specified
    199   virtual const SymbolTable* OutputSymbols() const = 0;
    200 
    201   // For generic state iterator construction; not normally called
    202   // directly by users.
    203   virtual void InitStateIterator(StateIteratorData<A> *) const = 0;
    204 
    205   // For generic arc iterator construction; not normally called
    206   // directly by users.
    207   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *) const = 0;
    208 };
    209 
    210 
    211 //
    212 // STATE and ARC ITERATOR DEFINITIONS
    213 //
    214 
    215 // State iterator interface templated on the Arc definition; used
    216 // for StateIterator specializations returned by InitStateIterator.
    217 template <class A>
    218 class StateIteratorBase {
    219  public:
    220   typedef A Arc;
    221   typedef typename A::StateId StateId;
    222 
    223   virtual ~StateIteratorBase() {}
    224   virtual bool Done() const = 0;      // End of iterator?
    225   virtual StateId Value() const = 0;  // Current state (when !Done)
    226   virtual void Next() = 0;            // Advance to next state (when !Done)
    227   virtual void Reset() = 0;           // Return to initial condition
    228 };
    229 
    230 
    231 // StateIterator initialization data
    232 template <class A> struct StateIteratorData {
    233   StateIteratorBase<A> *base;   // Specialized iterator if non-zero
    234   typename A::StateId nstates;  // O.w. total # of states
    235 };
    236 
    237 
    238 // Generic state iterator, templated on the FST definition
    239 // - a wrapper around pointer to specific one.
    240 // Here is a typical use: \code
    241 //   for (StateIterator<StdFst> siter(fst);
    242 //        !siter.Done();
    243 //        siter.Next()) {
    244 //     StateId s = siter.Value();
    245 //     ...
    246 //   } \endcode
    247 template <class F>
    248 class StateIterator {
    249  public:
    250   typedef typename F::Arc Arc;
    251   typedef typename Arc::StateId StateId;
    252 
    253   explicit StateIterator(const F &fst) : s_(0) {
    254     fst.InitStateIterator(&data_);
    255   }
    256 
    257   ~StateIterator() { if (data_.base) delete data_.base; }
    258 
    259   bool Done() const {
    260     return data_.base ? data_.base->Done() : s_ >= data_.nstates;
    261   }
    262 
    263   StateId Value() const { return data_.base ? data_.base->Value() : s_; }
    264 
    265   void Next() {
    266     if (data_.base)
    267       data_.base->Next();
    268     else
    269       ++s_;
    270   }
    271 
    272   void Reset() {
    273     if (data_.base)
    274       data_.base->Reset();
    275     else
    276       s_ = 0;
    277   }
    278 
    279  private:
    280   StateIteratorData<Arc> data_;
    281   StateId s_;
    282   DISALLOW_EVIL_CONSTRUCTORS(StateIterator);
    283 };
    284 
    285 
    286 // Arc iterator interface, templated on the Arc definition; used
    287 // for Arc iterator specializations that are returned by InitArcIterator.
    288 template <class A>
    289 class ArcIteratorBase {
    290  public:
    291   typedef A Arc;
    292   typedef typename A::StateId StateId;
    293 
    294   virtual ~ArcIteratorBase() {}
    295   virtual bool Done() const = 0;       // End of iterator?
    296   virtual const A& Value() const = 0;  // Current state (when !Done)
    297   virtual void Next() = 0;             // Advance to next arc (when !Done)
    298   virtual void Reset() = 0;            // Return to initial condition
    299   virtual void Seek(size_t a) = 0;     // Random arc access by position
    300 };
    301 
    302 
    303 // ArcIterator initialization data
    304 template <class A> struct ArcIteratorData {
    305   ArcIteratorBase<A> *base;  // Specialized iterator if non-zero
    306   const A *arcs;             // O.w. arcs pointer
    307   size_t narcs;              // ... and arc count
    308   int *ref_count;            // ... and reference count if non-zero
    309 };
    310 
    311 
    312 // Generic arc iterator, templated on the FST definition
    313 // - a wrapper around pointer to specific one.
    314 // Here is a typical use: \code
    315 //   for (ArcIterator<StdFst> aiter(fst, s));
    316 //        !aiter.Done();
    317 //         aiter.Next()) {
    318 //     StdArc &arc = aiter.Value();
    319 //     ...
    320 //   } \endcode
    321 template <class F>
    322 class ArcIterator {
    323    public:
    324   typedef typename F::Arc Arc;
    325   typedef typename Arc::StateId StateId;
    326 
    327   ArcIterator(const F &fst, StateId s) : i_(0) {
    328     fst.InitArcIterator(s, &data_);
    329   }
    330 
    331   ~ArcIterator() {
    332     if (data_.base)
    333       delete data_.base;
    334     else if (data_.ref_count)
    335     --(*data_.ref_count);
    336   }
    337 
    338   bool Done() const {
    339     return data_.base ?  data_.base->Done() : i_ >= data_.narcs;
    340   }
    341 
    342   const Arc& Value() const {
    343     return data_.base ? data_.base->Value() : data_.arcs[i_];
    344   }
    345 
    346   void Next() {
    347     if (data_.base)
    348       data_.base->Next();
    349     else
    350       ++i_;
    351   }
    352 
    353   void Reset() {
    354     if (data_.base)
    355       data_.base->Reset();
    356     else
    357       i_ = 0;
    358   }
    359 
    360   void Seek(size_t a) {
    361     if (data_.base)
    362       data_.base->Seek(a);
    363     else
    364       i_ = a;
    365   }
    366 
    367  private:
    368   ArcIteratorData<Arc> data_;
    369   size_t i_;
    370   DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
    371 };
    372 
    373 
    374 // A useful alias when using StdArc.
    375 typedef Fst<StdArc> StdFst;
    376 
    377 
    378 //
    379 //  CONSTANT DEFINITIONS
    380 //
    381 
    382 const int kNoStateId   =  -1;  // Not a valid state ID
    383 const int kNoLabel     =  -1;  // Not a valid label
    384 const int kPhiLabel    =  -2;  // Failure transition label
    385 const int kRhoLabel    =  -3;  // Matches o.w. unmatched labels (lib. internal)
    386 const int kSigmaLabel  =  -4;  // Matches all labels in alphabet.
    387 
    388 
    389 //
    390 // Fst IMPLEMENTATION BASE
    391 //
    392 // This is the recommended Fst implementation base class. It will
    393 // handle reference counts, property bits, type information and symbols.
    394 //
    395 
    396 template <class A> class FstImpl {
    397  public:
    398   typedef typename A::Weight Weight;
    399   typedef typename A::StateId StateId;
    400 
    401   FstImpl()
    402       : properties_(0), type_("null"), isymbols_(0), osymbols_(0),
    403         ref_count_(1) {}
    404 
    405   FstImpl(const FstImpl<A> &impl)
    406       : properties_(impl.properties_), type_(impl.type_),
    407         isymbols_(impl.isymbols_ ? new SymbolTable(impl.isymbols_) : 0),
    408         osymbols_(impl.osymbols_ ? new SymbolTable(impl.osymbols_) : 0),
    409         ref_count_(1) {}
    410 
    411   ~FstImpl() {
    412     delete isymbols_;
    413     delete osymbols_;
    414   }
    415 
    416   const string& Type() const { return type_; }
    417 
    418   void SetType(const string &type) { type_ = type; }
    419 
    420   uint64 Properties() const { return properties_; }
    421 
    422   uint64 Properties(uint64 mask) const { return properties_ & mask; }
    423 
    424   void SetProperties(uint64 props) { properties_ = props; }
    425 
    426   void SetProperties(uint64 props, uint64 mask) {
    427     properties_ &= ~mask;
    428     properties_ |= props & mask;
    429   }
    430 
    431   const SymbolTable* InputSymbols() const { return isymbols_; }
    432 
    433   const SymbolTable* OutputSymbols() const { return osymbols_; }
    434 
    435   SymbolTable* InputSymbols() { return isymbols_; }
    436 
    437   SymbolTable* OutputSymbols() { return osymbols_; }
    438 
    439   void SetInputSymbols(const SymbolTable* isyms) {
    440     if (isymbols_) delete isymbols_;
    441     isymbols_ = isyms ? isyms->Copy() : 0;
    442   }
    443 
    444   void SetOutputSymbols(const SymbolTable* osyms) {
    445     if (osymbols_) delete osymbols_;
    446     osymbols_ = osyms ? osyms->Copy() : 0;
    447   }
    448 
    449   int RefCount() const { return ref_count_; }
    450 
    451   int IncrRefCount() { return ++ref_count_; }
    452 
    453   int DecrRefCount() { return --ref_count_; }
    454 
    455   // Read-in header and symbols, initialize Fst, and return the header.
    456   // If opts.header is non-null, skip read-in and use the option value.
    457   // If opts.[io]symbols is non-null, read-in but use the option value.
    458   bool ReadHeaderAndSymbols(istream &strm, const FstReadOptions& opts,
    459                   int min_version, FstHeader *hdr) {
    460     if (opts.header)
    461       *hdr = *opts.header;
    462     else if (!hdr->Read(strm, opts.source))
    463       return false;
    464     if (hdr->FstType() != type_) {
    465       LOG(ERROR) << "FstImpl::ReadHeaderAndSymbols: Fst not of type \""
    466                  << type_ << "\": " << opts.source;
    467       return false;
    468     }
    469     if (hdr->ArcType() != A::Type()) {
    470       LOG(ERROR) << "FstImpl::ReadHeaderAndSymbols: Arc not of type \""
    471                  << A::Type()
    472                  << "\": " << opts.source;
    473       return false;
    474     }
    475     if (hdr->Version() < min_version) {
    476       LOG(ERROR) << "FstImpl::ReadHeaderAndSymbols: Obsolete "
    477                  << type_ << " Fst version: " << opts.source;
    478       return false;
    479     }
    480     properties_ = hdr->Properties();
    481     if (hdr->GetFlags() & FstHeader::HAS_ISYMBOLS)
    482       isymbols_ = SymbolTable::Read(strm, opts.source);
    483     if (hdr->GetFlags() & FstHeader::HAS_OSYMBOLS)
    484       osymbols_ =SymbolTable::Read(strm, opts.source);
    485 
    486     if (opts.isymbols) {
    487       delete isymbols_;
    488       isymbols_ = opts.isymbols->Copy();
    489     }
    490     if (opts.osymbols) {
    491       delete osymbols_;
    492       osymbols_ = opts.osymbols->Copy();
    493     }
    494     return true;
    495   }
    496 
    497   // Write-out header and symbols.
    498   // If a opts.header is false, skip writing header.
    499   // If opts.[io]symbols is false, skip writing those symbols.
    500   void WriteHeaderAndSymbols(ostream &strm, const FstWriteOptions& opts,
    501                              int version, FstHeader *hdr) const {
    502     if (opts.write_header) {
    503       hdr->SetFstType(type_);
    504       hdr->SetArcType(A::Type());
    505       hdr->SetVersion(version);
    506       hdr->SetProperties(properties_);
    507       int32 file_flags = 0;
    508       if (isymbols_ && opts.write_isymbols)
    509         file_flags |= FstHeader::HAS_ISYMBOLS;
    510       if (osymbols_ && opts.write_osymbols)
    511         file_flags |= FstHeader::HAS_OSYMBOLS;
    512       hdr->SetFlags(file_flags);
    513       hdr->Write(strm, opts.source);
    514     }
    515     if (isymbols_ && opts.write_isymbols) isymbols_->Write(strm);
    516     if (osymbols_ && opts.write_osymbols) osymbols_->Write(strm);
    517   }
    518 
    519  protected:
    520   uint64 properties_;           // Property bits
    521 
    522  private:
    523   string type_;                 // Unique name of Fst class
    524   SymbolTable *isymbols_;       // Ilabel symbol table
    525   SymbolTable *osymbols_;       // Olabel symbol table
    526   int ref_count_;               // Reference count
    527 
    528   void operator=(const FstImpl<A> &impl);  // disallow
    529 };
    530 
    531 }  // namespace fst;
    532 
    533 #endif  // FST_LIB_FST_H__
    534