Home | History | Annotate | Download | only in lib
      1 // const-fst.h
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 //
     16 // \file
     17 // Simple concrete immutable FST whose states and arcs are each stored
     18 // in single arrays.
     19 
     20 #ifndef FST_LIB_CONST_FST_H__
     21 #define FST_LIB_CONST_FST_H__
     22 
     23 #include "fst/lib/expanded-fst.h"
     24 #include "fst/lib/test-properties.h"
     25 
     26 namespace fst {
     27 
     28 template <class A> class ConstFst;
     29 
     30 // States and arcs each implemented by single arrays, templated on the
     31 // Arc definition.
     32 template <class A>
     33 class ConstFstImpl : public FstImpl<A> {
     34  public:
     35   using FstImpl<A>::SetType;
     36   using FstImpl<A>::SetProperties;
     37   using FstImpl<A>::Properties;
     38   using FstImpl<A>::WriteHeaderAndSymbols;
     39 
     40   typedef typename A::Weight Weight;
     41   typedef typename A::StateId StateId;
     42 
     43   ConstFstImpl()
     44       : states_(0), arcs_(0), nstates_(0), narcs_(0), start_(kNoStateId) {
     45     SetType("const");
     46     SetProperties(kNullProperties | kStaticProperties);
     47   }
     48 
     49   explicit ConstFstImpl(const Fst<A> &fst);
     50 
     51   ~ConstFstImpl() {
     52     delete[] states_;
     53     delete[] arcs_;
     54   }
     55 
     56   StateId Start() const { return start_; }
     57 
     58   Weight Final(StateId s) const { return states_[s].final; }
     59 
     60   StateId NumStates() const { return nstates_; }
     61 
     62   size_t NumArcs(StateId s) const { return states_[s].narcs; }
     63 
     64   size_t NumInputEpsilons(StateId s) const { return states_[s].niepsilons; }
     65 
     66   size_t NumOutputEpsilons(StateId s) const { return states_[s].noepsilons; }
     67 
     68   static ConstFstImpl<A> *Read(istream &strm, const FstReadOptions &opts);
     69 
     70   bool Write(ostream &strm, const FstWriteOptions &opts) const;
     71 
     72   A *Arcs(StateId s) { return arcs_ + states_[s].pos; }
     73 
     74   // Provide information needed for generic state iterator
     75   void InitStateIterator(StateIteratorData<A> *data) const {
     76     data->base = 0;
     77     data->nstates = nstates_;
     78   }
     79 
     80   // Provide information needed for the generic arc iterator
     81   void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
     82     data->base = 0;
     83     data->arcs = arcs_ + states_[s].pos;
     84     data->narcs = states_[s].narcs;
     85     data->ref_count = 0;
     86   }
     87 
     88  private:
     89   // States implemented by array *states_ below, arcs by (single) *arcs_.
     90   struct State {
     91     Weight final;                // Final weight
     92     uint32 pos;                  // Start of state's arcs in *arcs_
     93     uint32 narcs;                // Number of arcs (per state)
     94     uint32 niepsilons;           // # of input epsilons
     95     uint32 noepsilons;           // # of output epsilons
     96     State() : final(Weight::Zero()), niepsilons(0), noepsilons(0) {}
     97   };
     98 
     99   // Properties always true of this Fst class
    100   static const uint64 kStaticProperties = kExpanded;
    101   // Current file format version
    102   static const int kFileVersion = 1;
    103   // Minimum file format version supported
    104   static const int kMinFileVersion = 1;
    105   // Byte alignment for states and arcs in file format
    106   static const int kFileAlign = 16;
    107 
    108   State *states_;                // States represenation
    109   A *arcs_;                      // Arcs representation
    110   StateId nstates_;              // Number of states
    111   size_t narcs_;                 // Number of arcs (per FST)
    112   StateId start_;                // Initial state
    113 
    114   DISALLOW_EVIL_CONSTRUCTORS(ConstFstImpl);
    115 };
    116 
    117 template<class A>
    118 ConstFstImpl<A>::ConstFstImpl(const Fst<A> &fst) : nstates_(0), narcs_(0) {
    119   SetType("const");
    120   uint64 copy_properties = fst.Properties(kCopyProperties, true);
    121   SetProperties(copy_properties | kStaticProperties);
    122   SetInputSymbols(fst.InputSymbols());
    123   SetOutputSymbols(fst.OutputSymbols());
    124   start_ = fst.Start();
    125 
    126   // count # of states and arcs
    127   for (StateIterator< Fst<A> > siter(fst);
    128        !siter.Done();
    129        siter.Next()) {
    130     ++nstates_;
    131     StateId s = siter.Value();
    132     for (ArcIterator< Fst<A> > aiter(fst, s);
    133          !aiter.Done();
    134          aiter.Next())
    135       ++narcs_;
    136   }
    137   states_ = new State[nstates_];
    138   arcs_ = new A[narcs_];
    139   size_t pos = 0;
    140   for (StateId s = 0; s < nstates_; ++s) {
    141     states_[s].final = fst.Final(s);
    142     states_[s].pos = pos;
    143     states_[s].narcs = 0;
    144     states_[s].niepsilons = 0;
    145     states_[s].noepsilons = 0;
    146     for (ArcIterator< Fst<A> > aiter(fst, s);
    147          !aiter.Done();
    148          aiter.Next()) {
    149       const A &arc = aiter.Value();
    150       ++states_[s].narcs;
    151       if (arc.ilabel == 0)
    152         ++states_[s].niepsilons;
    153       if (arc.olabel == 0)
    154         ++states_[s].noepsilons;
    155       arcs_[pos++] = arc;
    156     }
    157   }
    158 }
    159 
    160 template<class A>
    161 ConstFstImpl<A> *ConstFstImpl<A>::Read(istream &strm,
    162                                        const FstReadOptions &opts) {
    163   ConstFstImpl<A> *impl = new ConstFstImpl<A>;
    164   FstHeader hdr;
    165   if (!impl->ReadHeaderAndSymbols(strm, opts, kMinFileVersion, &hdr))
    166     return 0;
    167   impl->start_ = hdr.Start();
    168   impl->nstates_ = hdr.NumStates();
    169   impl->narcs_ = hdr.NumArcs();
    170   impl->states_ = new State[impl->nstates_];
    171   impl->arcs_ = new A[impl->narcs_];
    172 
    173   char c;
    174   for (int i = 0; i < kFileAlign && strm.tellg() % kFileAlign; ++i)
    175     strm.read(&c, 1);
    176   // TODO: memory map this
    177   size_t b = impl->nstates_ * sizeof(typename ConstFstImpl<A>::State);
    178   strm.read(reinterpret_cast<char *>(impl->states_), b);
    179   if (!strm) {
    180     LOG(ERROR) << "ConstFst::Read: Read failed: " << opts.source;
    181     return 0;
    182   }
    183   // TODO: memory map this
    184   b = impl->narcs_ * sizeof(A);
    185   for (int i = 0; i < kFileAlign && strm.tellg() % kFileAlign; ++i)
    186     strm.read(&c, 1);
    187   strm.read(reinterpret_cast<char *>(impl->arcs_), b);
    188   if (!strm) {
    189     LOG(ERROR) << "ConstFst::Read: Read failed: " << opts.source;
    190     return 0;
    191   }
    192   return impl;
    193 }
    194 
    195 template<class A>
    196 bool ConstFstImpl<A>::Write(ostream &strm,
    197                             const FstWriteOptions &opts) const {
    198   FstHeader hdr;
    199   hdr.SetStart(start_);
    200   hdr.SetNumStates(nstates_);
    201   hdr.SetNumArcs(narcs_);
    202   WriteHeaderAndSymbols(strm, opts, kFileVersion, &hdr);
    203   if (!strm)
    204     return false;
    205 
    206   for (int i = 0; i < kFileAlign && strm.tellp() % kFileAlign; ++i)
    207     strm.write("", 1);
    208   strm.write(reinterpret_cast<char *>(states_),
    209              nstates_ * sizeof(State));
    210   for (int i = 0; i < kFileAlign && strm.tellp() % kFileAlign; ++i)
    211     strm.write("", 1);
    212   strm.write(reinterpret_cast<char *>(arcs_), narcs_ * sizeof(A));
    213   strm.flush();
    214   if (!strm)
    215     LOG(ERROR) << "ConstFst::Write: Write failed: " << opts.source;
    216   return strm;
    217 }
    218 
    219 // Simple concrete immutable FST.  This class attaches interface to
    220 // implementation and handles reference counting.
    221 template <class A>
    222 class ConstFst : public ExpandedFst<A> {
    223  public:
    224   friend class StateIterator< ConstFst<A> >;
    225   friend class ArcIterator< ConstFst<A> >;
    226 
    227   typedef A Arc;
    228   typedef typename A::Weight Weight;
    229   typedef typename A::StateId StateId;
    230 
    231   ConstFst() : impl_(new ConstFstImpl<A>()) {}
    232 
    233   ConstFst(const ConstFst<A> &fst) : impl_(fst.impl_) {
    234     impl_->IncrRefCount();
    235   }
    236 
    237   explicit ConstFst(const Fst<A> &fst) : impl_(new ConstFstImpl<A>(fst)) {}
    238 
    239   virtual ~ConstFst() { if (!impl_->DecrRefCount()) delete impl_;  }
    240 
    241   virtual StateId Start() const { return impl_->Start(); }
    242 
    243   virtual Weight Final(StateId s) const { return impl_->Final(s); }
    244 
    245   StateId NumStates() const { return impl_->NumStates(); }
    246 
    247   size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
    248 
    249   size_t NumInputEpsilons(StateId s) const {
    250     return impl_->NumInputEpsilons(s);
    251   }
    252 
    253   size_t NumOutputEpsilons(StateId s) const {
    254     return impl_->NumOutputEpsilons(s);
    255   }
    256 
    257   virtual uint64 Properties(uint64 mask, bool test) const {
    258     if (test) {
    259       uint64 known, test = TestProperties(*this, mask, &known);
    260       impl_->SetProperties(test, known);
    261       return test & mask;
    262     } else {
    263       return impl_->Properties(mask);
    264     }
    265   }
    266 
    267   virtual const string& Type() const { return impl_->Type(); }
    268 
    269   // Get a copy of this ConstFst
    270   virtual ConstFst<A> *Copy() const {
    271     impl_->IncrRefCount();
    272     return new ConstFst<A>(impl_);
    273   }
    274 
    275   // Read a ConstFst from an input stream; return NULL on error
    276   static ConstFst<A> *Read(istream &strm, const FstReadOptions &opts) {
    277     ConstFstImpl<A>* impl = ConstFstImpl<A>::Read(strm, opts);
    278     return impl ? new ConstFst<A>(impl) : 0;
    279   }
    280 
    281   // Read a ConstFst from a file; returno NULL on error
    282   static ConstFst<A> *Read(const string &filename) {
    283     ifstream strm(filename.c_str());
    284     if (!strm) {
    285       LOG(ERROR) << "ConstFst::Write: Can't open file: " << filename;
    286       return 0;
    287     }
    288     return Read(strm, FstReadOptions(filename));
    289   }
    290 
    291   // Write a ConstFst to an output stream; return false on error
    292   virtual bool Write(ostream &strm, const FstWriteOptions &opts) const {
    293     return impl_->Write(strm, opts);
    294   }
    295 
    296   // Write a ConstFst to a file; return false on error
    297   virtual bool Write(const string &filename) const {
    298     if (!filename.empty()) {
    299       ofstream strm(filename.c_str());
    300       if (!strm) {
    301         LOG(ERROR) << "ConstrFst::Write: Can't open file: " << filename;
    302         return false;
    303       }
    304       return Write(strm, FstWriteOptions(filename));
    305     } else {
    306       return Write(std::cout, FstWriteOptions("standard output"));
    307     }
    308   }
    309 
    310   virtual const SymbolTable* InputSymbols() const {
    311     return impl_->InputSymbols();
    312   }
    313 
    314   virtual const SymbolTable* OutputSymbols() const {
    315     return impl_->OutputSymbols();
    316   }
    317 
    318   virtual void InitStateIterator(StateIteratorData<A> *data) const {
    319     impl_->InitStateIterator(data);
    320   }
    321 
    322   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    323     impl_->InitArcIterator(s, data);
    324   }
    325 
    326  private:
    327   ConstFst(ConstFstImpl<A> *impl) : impl_(impl) {}
    328 
    329   ConstFstImpl<A> *impl_;  // FST's impl
    330 
    331   void operator=(const ConstFst<A> &fst);  // disallow
    332 };
    333 
    334 // Specialization for ConstFst; see generic version in fst.h
    335 // for sample usage (but use the ConstFst type!). This version
    336 // should inline.
    337 template <class A>
    338 class StateIterator< ConstFst<A> > {
    339  public:
    340   typedef typename A::StateId StateId;
    341 
    342   explicit StateIterator(const ConstFst<A> &fst)
    343     : nstates_(fst.impl_->NumStates()), s_(0) {}
    344 
    345   bool Done() const { return s_ >= nstates_; }
    346 
    347   StateId Value() const { return s_; }
    348 
    349   void Next() { ++s_; }
    350 
    351   void Reset() { s_ = 0; }
    352 
    353  private:
    354   StateId nstates_;
    355   StateId s_;
    356 
    357   DISALLOW_EVIL_CONSTRUCTORS(StateIterator);
    358 };
    359 
    360 // Specialization for ConstFst; see generic version in fst.h
    361 // for sample usage (but use the ConstFst type!). This version
    362 // should inline.
    363 template <class A>
    364 class ArcIterator< ConstFst<A> > {
    365  public:
    366   typedef typename A::StateId StateId;
    367 
    368   ArcIterator(const ConstFst<A> &fst, StateId s)
    369     : arcs_(fst.impl_->Arcs(s)), narcs_(fst.impl_->NumArcs(s)), i_(0) {}
    370 
    371   bool Done() const { return i_ >= narcs_; }
    372 
    373   const A& Value() const { return arcs_[i_]; }
    374 
    375   void Next() { ++i_; }
    376 
    377   void Reset() { i_ = 0; }
    378 
    379   void Seek(size_t a) { i_ = a; }
    380 
    381  private:
    382   const A *arcs_;
    383   size_t narcs_;
    384   size_t i_;
    385 
    386   DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
    387 };
    388 
    389 // A useful alias when using StdArc.
    390 typedef ConstFst<StdArc> StdConstFst;
    391 
    392 }  // namespace fst;
    393 
    394 #endif  // FST_LIB_CONST_FST_H__
    395