Home | History | Annotate | Download | only in fst
      1 // const-fst.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // \file
     19 // Simple concrete immutable FST whose states and arcs are each stored
     20 // in single arrays.
     21 
     22 #ifndef FST_LIB_CONST_FST_H__
     23 #define FST_LIB_CONST_FST_H__
     24 
     25 #include <string>
     26 #include <vector>
     27 using std::vector;
     28 
     29 #include <fst/expanded-fst.h>
     30 #include <fst/fst-decl.h>  // For optional argument declarations
     31 #include <fst/mapped-file.h>
     32 #include <fst/test-properties.h>
     33 #include <fst/util.h>
     34 
     35 
     36 namespace fst {
     37 
     38 template <class A, class U> class ConstFst;
     39 template <class F, class G> void Cast(const F &, G *);
     40 
     41 // States and arcs each implemented by single arrays, templated on the
     42 // Arc definition. The unsigned type U is used to represent indices into
     43 // the arc array.
     44 template <class A, class U>
     45 class ConstFstImpl : public FstImpl<A> {
     46  public:
     47   using FstImpl<A>::SetInputSymbols;
     48   using FstImpl<A>::SetOutputSymbols;
     49   using FstImpl<A>::SetType;
     50   using FstImpl<A>::SetProperties;
     51   using FstImpl<A>::Properties;
     52 
     53   typedef A Arc;
     54   typedef typename A::Weight Weight;
     55   typedef typename A::StateId StateId;
     56   typedef U Unsigned;
     57 
     58   ConstFstImpl()
     59       : states_region_(0), arcs_region_(0), states_(0), arcs_(0), nstates_(0),
     60         narcs_(0), start_(kNoStateId) {
     61     string type = "const";
     62     if (sizeof(U) != sizeof(uint32)) {
     63       string size;
     64       Int64ToStr(8 * sizeof(U), &size);
     65       type += size;
     66     }
     67     SetType(type);
     68     SetProperties(kNullProperties | kStaticProperties);
     69   }
     70 
     71   explicit ConstFstImpl(const Fst<A> &fst);
     72 
     73   ~ConstFstImpl() {
     74     delete arcs_region_;
     75     delete states_region_;
     76   }
     77 
     78   StateId Start() const { return start_; }
     79 
     80   Weight Final(StateId s) const { return states_[s].final; }
     81 
     82   StateId NumStates() const { return nstates_; }
     83 
     84   size_t NumArcs(StateId s) const { return states_[s].narcs; }
     85 
     86   size_t NumInputEpsilons(StateId s) const { return states_[s].niepsilons; }
     87 
     88   size_t NumOutputEpsilons(StateId s) const { return states_[s].noepsilons; }
     89 
     90   static ConstFstImpl<A, U> *Read(istream &strm, const FstReadOptions &opts);
     91 
     92   A *Arcs(StateId s) { return arcs_ + states_[s].pos; }
     93 
     94   // Provide information needed for generic state iterator
     95   void InitStateIterator(StateIteratorData<A> *data) const {
     96     data->base = 0;
     97     data->nstates = nstates_;
     98   }
     99 
    100   // Provide information needed for the generic arc iterator
    101   void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    102     data->base = 0;
    103     data->arcs = arcs_ + states_[s].pos;
    104     data->narcs = states_[s].narcs;
    105     data->ref_count = 0;
    106   }
    107 
    108  private:
    109   friend class ConstFst<A, U>;  // Allow finding narcs_, nstates_ during Write
    110 
    111   // States implemented by array *states_ below, arcs by (single) *arcs_.
    112   struct State {
    113     Weight final;                // Final weight
    114     Unsigned pos;                // Start of state's arcs in *arcs_
    115     Unsigned narcs;              // Number of arcs (per state)
    116     Unsigned niepsilons;         // # of input epsilons
    117     Unsigned noepsilons;         // # of output epsilons
    118     State() : final(Weight::Zero()), niepsilons(0), noepsilons(0) {}
    119   };
    120 
    121   // Properties always true of this Fst class
    122   static const uint64 kStaticProperties = kExpanded;
    123   // Current unaligned file format version. The unaligned version was added and
    124   // made the default since the aligned version does not work on pipes.
    125   static const int kFileVersion = 2;
    126   // Current aligned file format version
    127   static const int kAlignedFileVersion = 1;
    128   // Minimum file format version supported
    129   static const int kMinFileVersion = 1;
    130 
    131   MappedFile *states_region_;    // Mapped file for states
    132   MappedFile *arcs_region_;      // Mapped file for arcs
    133   State *states_;                // States represenation
    134   A *arcs_;                      // Arcs representation
    135   StateId nstates_;              // Number of states
    136   size_t narcs_;                 // Number of arcs (per FST)
    137   StateId start_;                // Initial state
    138 
    139   DISALLOW_COPY_AND_ASSIGN(ConstFstImpl);
    140 };
    141 
    142 template <class A, class U>
    143 const uint64 ConstFstImpl<A, U>::kStaticProperties;
    144 template <class A, class U>
    145 const int ConstFstImpl<A, U>::kFileVersion;
    146 template <class A, class U>
    147 const int ConstFstImpl<A, U>::kAlignedFileVersion;
    148 template <class A, class U>
    149 const int ConstFstImpl<A, U>::kMinFileVersion;
    150 
    151 
    152 template<class A, class U>
    153 ConstFstImpl<A, U>::ConstFstImpl(const Fst<A> &fst) : nstates_(0), narcs_(0) {
    154   string type = "const";
    155   if (sizeof(U) != sizeof(uint32)) {
    156     string size;
    157     Int64ToStr(sizeof(U) * 8, &size);
    158     type += size;
    159   }
    160   SetType(type);
    161   SetInputSymbols(fst.InputSymbols());
    162   SetOutputSymbols(fst.OutputSymbols());
    163   start_ = fst.Start();
    164 
    165   // Count # of states and arcs.
    166   for (StateIterator< Fst<A> > siter(fst);
    167        !siter.Done();
    168        siter.Next()) {
    169     ++nstates_;
    170     StateId s = siter.Value();
    171     for (ArcIterator< Fst<A> > aiter(fst, s);
    172          !aiter.Done();
    173          aiter.Next())
    174       ++narcs_;
    175   }
    176   states_region_ = MappedFile::Allocate(nstates_ * sizeof(*states_));
    177   arcs_region_ = MappedFile::Allocate(narcs_ * sizeof(*arcs_));
    178   states_ = reinterpret_cast<State*>(states_region_->mutable_data());
    179   arcs_ = reinterpret_cast<A*>(arcs_region_->mutable_data());
    180   size_t pos = 0;
    181   for (StateId s = 0; s < nstates_; ++s) {
    182     states_[s].final = fst.Final(s);
    183     states_[s].pos = pos;
    184     states_[s].narcs = 0;
    185     states_[s].niepsilons = 0;
    186     states_[s].noepsilons = 0;
    187     for (ArcIterator< Fst<A> > aiter(fst, s);
    188          !aiter.Done();
    189          aiter.Next()) {
    190       const A &arc = aiter.Value();
    191       ++states_[s].narcs;
    192       if (arc.ilabel == 0)
    193         ++states_[s].niepsilons;
    194       if (arc.olabel == 0)
    195         ++states_[s].noepsilons;
    196       arcs_[pos++] = arc;
    197     }
    198   }
    199   SetProperties(fst.Properties(kCopyProperties, true) | kStaticProperties);
    200 }
    201 
    202 
    203 template<class A, class U>
    204 ConstFstImpl<A, U> *ConstFstImpl<A, U>::Read(istream &strm,
    205                                              const FstReadOptions &opts) {
    206   ConstFstImpl<A, U> *impl = new ConstFstImpl<A, U>;
    207   FstHeader hdr;
    208   if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) {
    209     delete impl;
    210     return 0;
    211   }
    212   impl->start_ = hdr.Start();
    213   impl->nstates_ = hdr.NumStates();
    214   impl->narcs_ = hdr.NumArcs();
    215 
    216   // Ensures compatibility
    217   if (hdr.Version() == kAlignedFileVersion)
    218     hdr.SetFlags(hdr.GetFlags() | FstHeader::IS_ALIGNED);
    219 
    220   if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && !AlignInput(strm)) {
    221     LOG(ERROR) << "ConstFst::Read: Alignment failed: " << opts.source;
    222     delete impl;
    223     return 0;
    224   }
    225 
    226   size_t b = impl->nstates_ * sizeof(typename ConstFstImpl<A, U>::State);
    227   impl->states_region_ = MappedFile::Map(&strm, opts, b);
    228   if (!strm || impl->states_region_ == NULL) {
    229     LOG(ERROR) << "ConstFst::Read: Read failed: " << opts.source;
    230     delete impl;
    231     return 0;
    232   }
    233   impl->states_ = reinterpret_cast<State*>(
    234       impl->states_region_->mutable_data());
    235   if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && !AlignInput(strm)) {
    236     LOG(ERROR) << "ConstFst::Read: Alignment failed: " << opts.source;
    237     delete impl;
    238     return 0;
    239   }
    240 
    241   b = impl->narcs_ * sizeof(A);
    242   impl->arcs_region_ = MappedFile::Map(&strm, opts, b);
    243   if (!strm || impl->arcs_region_ == NULL) {
    244     LOG(ERROR) << "ConstFst::Read: Read failed: " << opts.source;
    245     delete impl;
    246     return 0;
    247   }
    248   impl->arcs_ = reinterpret_cast<A*>(impl->arcs_region_->mutable_data());
    249   return impl;
    250 }
    251 
    252 // Simple concrete immutable FST.  This class attaches interface to
    253 // implementation and handles reference counting, delegating most
    254 // methods to ImplToExpandedFst. The unsigned type U is used to
    255 // represent indices into the arc array (uint32 by default, declared
    256 // in fst-decl.h).
    257 template <class A, class U>
    258 class ConstFst : public ImplToExpandedFst< ConstFstImpl<A, U> > {
    259  public:
    260   friend class StateIterator< ConstFst<A, U> >;
    261   friend class ArcIterator< ConstFst<A, U> >;
    262   template <class F, class G> void friend Cast(const F &, G *);
    263 
    264   typedef A Arc;
    265   typedef typename A::StateId StateId;
    266   typedef ConstFstImpl<A, U> Impl;
    267   typedef U Unsigned;
    268 
    269   ConstFst() : ImplToExpandedFst<Impl>(new Impl()) {}
    270 
    271   explicit ConstFst(const Fst<A> &fst)
    272       : ImplToExpandedFst<Impl>(new Impl(fst)) {}
    273 
    274   ConstFst(const ConstFst<A, U> &fst) : ImplToExpandedFst<Impl>(fst) {}
    275 
    276   // Get a copy of this ConstFst. See Fst<>::Copy() for further doc.
    277   virtual ConstFst<A, U> *Copy(bool safe = false) const {
    278     return new ConstFst<A, U>(*this);
    279   }
    280 
    281   // Read a ConstFst from an input stream; return NULL on error
    282   static ConstFst<A, U> *Read(istream &strm, const FstReadOptions &opts) {
    283     Impl* impl = Impl::Read(strm, opts);
    284     return impl ? new ConstFst<A, U>(impl) : 0;
    285   }
    286 
    287   // Read a ConstFst from a file; return NULL on error
    288   // Empty filename reads from standard input
    289   static ConstFst<A, U> *Read(const string &filename) {
    290     Impl* impl = ImplToExpandedFst<Impl>::Read(filename);
    291     return impl ? new ConstFst<A, U>(impl) : 0;
    292   }
    293 
    294   virtual bool Write(ostream &strm, const FstWriteOptions &opts) const {
    295     return WriteFst(*this, strm, opts);
    296   }
    297 
    298   virtual bool Write(const string &filename) const {
    299     return Fst<A>::WriteFile(filename);
    300   }
    301 
    302   template <class F>
    303   static bool WriteFst(const F &fst, ostream &strm,
    304                        const FstWriteOptions &opts);
    305 
    306   virtual void InitStateIterator(StateIteratorData<Arc> *data) const {
    307     GetImpl()->InitStateIterator(data);
    308   }
    309 
    310   virtual void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
    311     GetImpl()->InitArcIterator(s, data);
    312   }
    313 
    314  private:
    315   explicit ConstFst(Impl *impl) : ImplToExpandedFst<Impl>(impl) {}
    316 
    317   // Makes visible to friends.
    318   Impl *GetImpl() const { return ImplToFst<Impl, ExpandedFst<A> >::GetImpl(); }
    319 
    320   void SetImpl(Impl *impl, bool own_impl = true) {
    321     ImplToFst< Impl, ExpandedFst<A> >::SetImpl(impl, own_impl);
    322   }
    323 
    324   // Use overloading to extract the type of the argument.
    325   static Impl* GetImplIfConstFst(const ConstFst &const_fst) {
    326     return const_fst.GetImpl();
    327   }
    328 
    329   // Note that this does not give privileged treatment to subtypes of ConstFst.
    330   template<typename NonConstFst>
    331   static Impl* GetImplIfConstFst(const NonConstFst& fst) {
    332     return NULL;
    333   }
    334 
    335   void operator=(const ConstFst<A, U> &fst);  // disallow
    336 };
    337 
    338 // Writes Fst in Const format, potentially with a pass over the machine
    339 // before writing to compute number of states and arcs.
    340 //
    341 template <class A, class U>
    342 template <class F>
    343 bool ConstFst<A, U>::WriteFst(const F &fst, ostream &strm,
    344                               const FstWriteOptions &opts) {
    345   int file_version = opts.align ? ConstFstImpl<A, U>::kAlignedFileVersion :
    346       ConstFstImpl<A, U>::kFileVersion;
    347   size_t num_arcs = -1, num_states = -1;
    348   size_t start_offset = 0;
    349   bool update_header = true;
    350   if (Impl* impl = GetImplIfConstFst(fst)) {
    351     num_arcs = impl->narcs_;
    352     num_states = impl->nstates_;
    353     update_header = false;
    354   } else if ((start_offset = strm.tellp()) == -1) {
    355     // precompute values needed for header when we cannot seek to rewrite it.
    356     num_arcs = 0;
    357     num_states = 0;
    358     for (StateIterator<F> siter(fst); !siter.Done(); siter.Next()) {
    359       num_arcs += fst.NumArcs(siter.Value());
    360       ++num_states;
    361     }
    362     update_header = false;
    363   }
    364   FstHeader hdr;
    365   hdr.SetStart(fst.Start());
    366   hdr.SetNumStates(num_states);
    367   hdr.SetNumArcs(num_arcs);
    368   string type = "const";
    369   if (sizeof(U) != sizeof(uint32)) {
    370     string size;
    371     Int64ToStr(8 * sizeof(U), &size);
    372     type += size;
    373   }
    374   uint64 properties = fst.Properties(kCopyProperties, true) |
    375       ConstFstImpl<A, U>::kStaticProperties;
    376   FstImpl<A>::WriteFstHeader(fst, strm, opts, file_version, type, properties,
    377                              &hdr);
    378   if (opts.align && !AlignOutput(strm)) {
    379     LOG(ERROR) << "Could not align file during write after header";
    380     return false;
    381   }
    382   size_t pos = 0, states = 0;
    383   typename ConstFstImpl<A, U>::State state;
    384   for (StateIterator<F> siter(fst); !siter.Done(); siter.Next()) {
    385     state.final = fst.Final(siter.Value());
    386     state.pos = pos;
    387     state.narcs = fst.NumArcs(siter.Value());
    388     state.niepsilons = fst.NumInputEpsilons(siter.Value());
    389     state.noepsilons = fst.NumOutputEpsilons(siter.Value());
    390     strm.write(reinterpret_cast<const char *>(&state), sizeof(state));
    391     pos += state.narcs;
    392     ++states;
    393   }
    394   hdr.SetNumStates(states);
    395   hdr.SetNumArcs(pos);
    396   if (opts.align && !AlignOutput(strm)) {
    397     LOG(ERROR) << "Could not align file during write after writing states";
    398   }
    399   for (StateIterator<F> siter(fst); !siter.Done(); siter.Next()) {
    400     StateId s = siter.Value();
    401     for (ArcIterator<F> aiter(fst, s); !aiter.Done(); aiter.Next()) {
    402       const A &arc = aiter.Value();
    403       strm.write(reinterpret_cast<const char *>(&arc), sizeof(arc));
    404     }
    405   }
    406   strm.flush();
    407   if (!strm) {
    408     LOG(ERROR) << "ConstFst Write write failed: " << opts.source;
    409     return false;
    410   }
    411   if (update_header) {
    412     return FstImpl<A>::UpdateFstHeader(fst, strm, opts, file_version, type,
    413                                        properties, &hdr, start_offset);
    414   } else {
    415     if (hdr.NumStates() != num_states) {
    416       LOG(ERROR) << "Inconsistent number of states observed during write";
    417       return false;
    418     }
    419     if (hdr.NumArcs() != num_arcs) {
    420       LOG(ERROR) << "Inconsistent number of arcs observed during write";
    421       return false;
    422     }
    423   }
    424   return true;
    425 }
    426 
    427 // Specialization for ConstFst; see generic version in fst.h
    428 // for sample usage (but use the ConstFst type!). This version
    429 // should inline.
    430 template <class A, class U>
    431 class StateIterator< ConstFst<A, U> > {
    432  public:
    433   typedef typename A::StateId StateId;
    434 
    435   explicit StateIterator(const ConstFst<A, U> &fst)
    436       : nstates_(fst.GetImpl()->NumStates()), s_(0) {}
    437 
    438   bool Done() const { return s_ >= nstates_; }
    439 
    440   StateId Value() const { return s_; }
    441 
    442   void Next() { ++s_; }
    443 
    444   void Reset() { s_ = 0; }
    445 
    446  private:
    447   StateId nstates_;
    448   StateId s_;
    449 
    450   DISALLOW_COPY_AND_ASSIGN(StateIterator);
    451 };
    452 
    453 
    454 // Specialization for ConstFst; see generic version in fst.h
    455 // for sample usage (but use the ConstFst type!). This version
    456 // should inline.
    457 template <class A, class U>
    458 class ArcIterator< ConstFst<A, U> > {
    459  public:
    460   typedef typename A::StateId StateId;
    461 
    462   ArcIterator(const ConstFst<A, U> &fst, StateId s)
    463       : arcs_(fst.GetImpl()->Arcs(s)),
    464         narcs_(fst.GetImpl()->NumArcs(s)), i_(0) {}
    465 
    466   bool Done() const { return i_ >= narcs_; }
    467 
    468   const A& Value() const { return arcs_[i_]; }
    469 
    470   void Next() { ++i_; }
    471 
    472   size_t Position() const { return i_; }
    473 
    474   void Reset() { i_ = 0; }
    475 
    476   void Seek(size_t a) { i_ = a; }
    477 
    478   uint32 Flags() const {
    479     return kArcValueFlags;
    480   }
    481 
    482   void SetFlags(uint32 f, uint32 m) {}
    483 
    484  private:
    485   const A *arcs_;
    486   size_t narcs_;
    487   size_t i_;
    488 
    489   DISALLOW_COPY_AND_ASSIGN(ArcIterator);
    490 };
    491 
    492 // A useful alias when using StdArc.
    493 typedef ConstFst<StdArc> StdConstFst;
    494 
    495 }  // namespace fst
    496 
    497 #endif  // FST_LIB_CONST_FST_H__
    498