Home | History | Annotate | Download | only in lib
      1 // const-fst.h
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 //
     16 // \file
     17 // Simple concrete immutable FST whose states and arcs are each stored
     18 // in single arrays.
     19 
     20 #ifndef FST_LIB_CONST_FST_H__
     21 #define FST_LIB_CONST_FST_H__
     22 
     23 #include "fst/lib/expanded-fst.h"
     24 #include "fst/lib/test-properties.h"
     25 
     26 namespace fst {
     27 
     28 template <class A> class ConstFst;
     29 
     30 // States and arcs each implemented by single arrays, templated on the
     31 // Arc definition.
     32 template <class A>
     33 class ConstFstImpl : public FstImpl<A> {
     34  public:
     35   using FstImpl<A>::SetType;
     36   using FstImpl<A>::SetProperties;
     37   using FstImpl<A>::Properties;
     38   using FstImpl<A>::WriteHeaderAndSymbols;
     39 
     40   typedef typename A::Weight Weight;
     41   typedef typename A::StateId StateId;
     42 
     43   ConstFstImpl()
     44       : states_(0), arcs_(0), nstates_(0), narcs_(0), start_(kNoStateId) {
     45     SetType("const");
     46     SetProperties(kNullProperties | kStaticProperties);
     47   }
     48 
     49   explicit ConstFstImpl(const Fst<A> &fst);
     50 
     51   ~ConstFstImpl() {
     52     delete[] states_;
     53     delete[] arcs_;
     54   }
     55 
     56   StateId Start() const { return start_; }
     57 
     58   Weight Final(StateId s) const { return states_[s].final; }
     59 
     60   StateId NumStates() const { return nstates_; }
     61 
     62   size_t NumArcs(StateId s) const { return states_[s].narcs; }
     63 
     64   size_t NumInputEpsilons(StateId s) const { return states_[s].niepsilons; }
     65 
     66   size_t NumOutputEpsilons(StateId s) const { return states_[s].noepsilons; }
     67 
     68   static ConstFstImpl<A> *Read(istream &strm, const FstReadOptions &opts);
     69 
     70   bool Write(ostream &strm, const FstWriteOptions &opts) const;
     71 
     72   A *Arcs(StateId s) { return arcs_ + states_[s].pos; }
     73 
     74   // Provide information needed for generic state iterator
     75   void InitStateIterator(StateIteratorData<A> *data) const {
     76     data->base = 0;
     77     data->nstates = nstates_;
     78   }
     79 
     80   // Provide information needed for the generic arc iterator
     81   void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
     82     data->base = 0;
     83     data->arcs = arcs_ + states_[s].pos;
     84     data->narcs = states_[s].narcs;
     85     data->ref_count = 0;
     86   }
     87 
     88  private:
     89   // States implemented by array *states_ below, arcs by (single) *arcs_.
     90   struct State {
     91     Weight final;                // Final weight
     92     uint32 pos;                  // Start of state's arcs in *arcs_
     93     uint32 narcs;                // Number of arcs (per state)
     94     uint32 niepsilons;           // # of input epsilons
     95     uint32 noepsilons;           // # of output epsilons
     96     State() : final(Weight::Zero()), niepsilons(0), noepsilons(0) {}
     97   };
     98 
     99   // Properties always true of this Fst class
    100   static const uint64 kStaticProperties = kExpanded;
    101   // Current file format version
    102   static const int kFileVersion = 1;
    103   // Minimum file format version supported
    104   static const int kMinFileVersion = 1;
    105   // Byte alignment for states and arcs in file format
    106   static const int kFileAlign = 16;
    107 
    108   State *states_;                // States represenation
    109   A *arcs_;                      // Arcs representation
    110   StateId nstates_;              // Number of states
    111   size_t narcs_;                 // Number of arcs (per FST)
    112   StateId start_;                // Initial state
    113 
    114   DISALLOW_EVIL_CONSTRUCTORS(ConstFstImpl);
    115 };
    116 
    117 template<class A>
    118 ConstFstImpl<A>::ConstFstImpl(const Fst<A> &fst) : nstates_(0), narcs_(0) {
    119   SetType("const");
    120   uint64 copy_properties = fst.Properties(kCopyProperties, true);
    121   SetProperties(copy_properties | kStaticProperties);
    122   this->SetInputSymbols(fst.InputSymbols());
    123   this->SetOutputSymbols(fst.OutputSymbols());
    124   start_ = fst.Start();
    125 
    126   // count # of states and arcs
    127   for (StateIterator< Fst<A> > siter(fst);
    128        !siter.Done();
    129        siter.Next()) {
    130     ++nstates_;
    131     StateId s = siter.Value();
    132     for (ArcIterator< Fst<A> > aiter(fst, s);
    133          !aiter.Done();
    134          aiter.Next())
    135       ++narcs_;
    136   }
    137   states_ = new State[nstates_];
    138   arcs_ = new A[narcs_];
    139   size_t pos = 0;
    140   for (StateId s = 0; s < nstates_; ++s) {
    141     states_[s].final = fst.Final(s);
    142     states_[s].pos = pos;
    143     states_[s].narcs = 0;
    144     states_[s].niepsilons = 0;
    145     states_[s].noepsilons = 0;
    146     for (ArcIterator< Fst<A> > aiter(fst, s);
    147          !aiter.Done();
    148          aiter.Next()) {
    149       const A &arc = aiter.Value();
    150       ++states_[s].narcs;
    151       if (arc.ilabel == 0)
    152         ++states_[s].niepsilons;
    153       if (arc.olabel == 0)
    154         ++states_[s].noepsilons;
    155       arcs_[pos++] = arc;
    156     }
    157   }
    158 }
    159 
    160 template<class A>
    161 ConstFstImpl<A> *ConstFstImpl<A>::Read(istream &strm,
    162                                        const FstReadOptions &opts) {
    163   ConstFstImpl<A> *impl = new ConstFstImpl<A>;
    164   FstHeader hdr;
    165   if (!impl->ReadHeaderAndSymbols(strm, opts, kMinFileVersion, &hdr))
    166     return 0;
    167   impl->start_ = hdr.Start();
    168   impl->nstates_ = hdr.NumStates();
    169   impl->narcs_ = hdr.NumArcs();
    170   impl->states_ = new State[impl->nstates_];
    171   impl->arcs_ = new A[impl->narcs_];
    172 
    173   char c;
    174   for (int i = 0; i < kFileAlign && strm.tellg() % kFileAlign; ++i)
    175     strm.read(&c, 1);
    176   // TODO: memory map this
    177   size_t b = impl->nstates_ * sizeof(typename ConstFstImpl<A>::State);
    178   strm.read(reinterpret_cast<char *>(impl->states_), b);
    179   if (!strm) {
    180     LOG(ERROR) << "ConstFst::Read: Read failed: " << opts.source;
    181     return 0;
    182   }
    183   // TODO: memory map this
    184   b = impl->narcs_ * sizeof(A);
    185   for (int i = 0; i < kFileAlign && strm.tellg() % kFileAlign; ++i)
    186     strm.read(&c, 1);
    187   strm.read(reinterpret_cast<char *>(impl->arcs_), b);
    188   if (!strm) {
    189     LOG(ERROR) << "ConstFst::Read: Read failed: " << opts.source;
    190     return 0;
    191   }
    192   return impl;
    193 }
    194 
    195 template<class A>
    196 bool ConstFstImpl<A>::Write(ostream &strm,
    197                             const FstWriteOptions &opts) const {
    198   FstHeader hdr;
    199   hdr.SetStart(start_);
    200   hdr.SetNumStates(nstates_);
    201   hdr.SetNumArcs(narcs_);
    202   WriteHeaderAndSymbols(strm, opts, kFileVersion, &hdr);
    203   if (!strm)
    204     return false;
    205 
    206   for (int i = 0; i < kFileAlign && strm.tellp() % kFileAlign; ++i)
    207     strm.write("", 1);
    208   strm.write(reinterpret_cast<char *>(states_),
    209              nstates_ * sizeof(State));
    210   for (int i = 0; i < kFileAlign && strm.tellp() % kFileAlign; ++i)
    211     strm.write("", 1);
    212   strm.write(reinterpret_cast<char *>(arcs_), narcs_ * sizeof(A));
    213   strm.flush();
    214   if (!strm) {
    215     LOG(ERROR) << "ConstFst::Write: Write failed: " << opts.source;
    216     return false;
    217   }
    218   return true;
    219 }
    220 
    221 // Simple concrete immutable FST.  This class attaches interface to
    222 // implementation and handles reference counting.
    223 template <class A>
    224 class ConstFst : public ExpandedFst<A> {
    225  public:
    226   friend class StateIterator< ConstFst<A> >;
    227   friend class ArcIterator< ConstFst<A> >;
    228 
    229   typedef A Arc;
    230   typedef typename A::Weight Weight;
    231   typedef typename A::StateId StateId;
    232 
    233   ConstFst() : impl_(new ConstFstImpl<A>()) {}
    234 
    235   ConstFst(const ConstFst<A> &fst) : impl_(fst.impl_) {
    236     impl_->IncrRefCount();
    237   }
    238 
    239   explicit ConstFst(const Fst<A> &fst) : impl_(new ConstFstImpl<A>(fst)) {}
    240 
    241   virtual ~ConstFst() { if (!impl_->DecrRefCount()) delete impl_;  }
    242 
    243   virtual StateId Start() const { return impl_->Start(); }
    244 
    245   virtual Weight Final(StateId s) const { return impl_->Final(s); }
    246 
    247   StateId NumStates() const { return impl_->NumStates(); }
    248 
    249   size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
    250 
    251   size_t NumInputEpsilons(StateId s) const {
    252     return impl_->NumInputEpsilons(s);
    253   }
    254 
    255   size_t NumOutputEpsilons(StateId s) const {
    256     return impl_->NumOutputEpsilons(s);
    257   }
    258 
    259   virtual uint64 Properties(uint64 mask, bool test) const {
    260     if (test) {
    261       uint64 known, test = TestProperties(*this, mask, &known);
    262       impl_->SetProperties(test, known);
    263       return test & mask;
    264     } else {
    265       return impl_->Properties(mask);
    266     }
    267   }
    268 
    269   virtual const string& Type() const { return impl_->Type(); }
    270 
    271   // Get a copy of this ConstFst
    272   virtual ConstFst<A> *Copy() const {
    273     impl_->IncrRefCount();
    274     return new ConstFst<A>(impl_);
    275   }
    276 
    277   // Read a ConstFst from an input stream; return NULL on error
    278   static ConstFst<A> *Read(istream &strm, const FstReadOptions &opts) {
    279     ConstFstImpl<A>* impl = ConstFstImpl<A>::Read(strm, opts);
    280     return impl ? new ConstFst<A>(impl) : 0;
    281   }
    282 
    283   // Read a ConstFst from a file; returno NULL on error
    284   static ConstFst<A> *Read(const string &filename) {
    285     ifstream strm(filename.c_str());
    286     if (!strm) {
    287       LOG(ERROR) << "ConstFst::Write: Can't open file: " << filename;
    288       return 0;
    289     }
    290     return Read(strm, FstReadOptions(filename));
    291   }
    292 
    293   // Write a ConstFst to an output stream; return false on error
    294   virtual bool Write(ostream &strm, const FstWriteOptions &opts) const {
    295     return impl_->Write(strm, opts);
    296   }
    297 
    298   // Write a ConstFst to a file; return false on error
    299   virtual bool Write(const string &filename) const {
    300     if (!filename.empty()) {
    301       ofstream strm(filename.c_str());
    302       if (!strm) {
    303         LOG(ERROR) << "ConstrFst::Write: Can't open file: " << filename;
    304         return false;
    305       }
    306       return Write(strm, FstWriteOptions(filename));
    307     } else {
    308       return Write(std::cout, FstWriteOptions("standard output"));
    309     }
    310   }
    311 
    312   virtual const SymbolTable* InputSymbols() const {
    313     return impl_->InputSymbols();
    314   }
    315 
    316   virtual const SymbolTable* OutputSymbols() const {
    317     return impl_->OutputSymbols();
    318   }
    319 
    320   virtual void InitStateIterator(StateIteratorData<A> *data) const {
    321     impl_->InitStateIterator(data);
    322   }
    323 
    324   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    325     impl_->InitArcIterator(s, data);
    326   }
    327 
    328  private:
    329   ConstFst(ConstFstImpl<A> *impl) : impl_(impl) {}
    330 
    331   ConstFstImpl<A> *impl_;  // FST's impl
    332 
    333   void operator=(const ConstFst<A> &fst);  // disallow
    334 };
    335 
    336 // Specialization for ConstFst; see generic version in fst.h
    337 // for sample usage (but use the ConstFst type!). This version
    338 // should inline.
    339 template <class A>
    340 class StateIterator< ConstFst<A> > {
    341  public:
    342   typedef typename A::StateId StateId;
    343 
    344   explicit StateIterator(const ConstFst<A> &fst)
    345     : nstates_(fst.impl_->NumStates()), s_(0) {}
    346 
    347   bool Done() const { return s_ >= nstates_; }
    348 
    349   StateId Value() const { return s_; }
    350 
    351   void Next() { ++s_; }
    352 
    353   void Reset() { s_ = 0; }
    354 
    355  private:
    356   StateId nstates_;
    357   StateId s_;
    358 
    359   DISALLOW_EVIL_CONSTRUCTORS(StateIterator);
    360 };
    361 
    362 // Specialization for ConstFst; see generic version in fst.h
    363 // for sample usage (but use the ConstFst type!). This version
    364 // should inline.
    365 template <class A>
    366 class ArcIterator< ConstFst<A> > {
    367  public:
    368   typedef typename A::StateId StateId;
    369 
    370   ArcIterator(const ConstFst<A> &fst, StateId s)
    371     : arcs_(fst.impl_->Arcs(s)), narcs_(fst.impl_->NumArcs(s)), i_(0) {}
    372 
    373   bool Done() const { return i_ >= narcs_; }
    374 
    375   const A& Value() const { return arcs_[i_]; }
    376 
    377   void Next() { ++i_; }
    378 
    379   void Reset() { i_ = 0; }
    380 
    381   void Seek(size_t a) { i_ = a; }
    382 
    383  private:
    384   const A *arcs_;
    385   size_t narcs_;
    386   size_t i_;
    387 
    388   DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
    389 };
    390 
    391 // A useful alias when using StdArc.
    392 typedef ConstFst<StdArc> StdConstFst;
    393 
    394 }  // namespace fst;
    395 
    396 #endif  // FST_LIB_CONST_FST_H__
    397