Home | History | Annotate | Download | only in fst
      1 // add-on.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // \file
     19 // Fst implementation class to attach an arbitrary object with a
     20 // read/write method to an FST and its file rep. The FST is given a
     21 // new type name.
     22 
     23 #ifndef FST_LIB_ADD_ON_FST_H__
     24 #define FST_LIB_ADD_ON_FST_H__
     25 
     26 #include <stddef.h>
     27 #include <string>
     28 
     29 #include <fst/fst.h>
     30 
     31 
     32 namespace fst {
     33 
     34 // Identifies stream data as an add-on fst.
     35 static const int32 kAddOnMagicNumber = 446681434;
     36 
     37 
     38 //
     39 // Some useful add-on objects.
     40 //
     41 
     42 // Nothing to save.
     43 class NullAddOn {
     44  public:
     45   NullAddOn() {}
     46 
     47   static NullAddOn *Read(istream &istrm) {
     48     return new NullAddOn();
     49   };
     50 
     51   bool Write(ostream &ostrm) const { return true; }
     52 
     53   int RefCount() const { return ref_count_.count(); }
     54   int IncrRefCount() { return ref_count_.Incr(); }
     55   int DecrRefCount() { return ref_count_.Decr(); }
     56 
     57  private:
     58   RefCounter ref_count_;
     59 
     60   DISALLOW_COPY_AND_ASSIGN(NullAddOn);
     61 };
     62 
     63 
     64 // Create a new add-on from a pair of add-ons.
     65 template <class A1, class A2>
     66 class AddOnPair {
     67  public:
     68   // Argument reference count incremented.
     69   AddOnPair(A1 *a1, A2 *a2)
     70       : a1_(a1), a2_(a2) {
     71     if (a1_)
     72       a1_->IncrRefCount();
     73     if (a2_)
     74       a2_->IncrRefCount();
     75   }
     76 
     77   ~AddOnPair() {
     78     if (a1_ && !a1_->DecrRefCount())
     79       delete a1_;
     80     if (a2_ && !a2_->DecrRefCount())
     81       delete a2_;
     82   }
     83 
     84   A1 *First() const { return a1_; }
     85   A2 *Second() const { return a2_; }
     86 
     87   static AddOnPair<A1, A2> *Read(istream &istrm) {
     88     A1 *a1 = 0;
     89     bool have_addon1 = false;
     90     ReadType(istrm, &have_addon1);
     91     if (have_addon1)
     92       a1 = A1::Read(istrm);
     93 
     94     A2 *a2 = 0;
     95     bool have_addon2 = false;
     96     ReadType(istrm, &have_addon2);
     97     if (have_addon2)
     98       a2 = A2::Read(istrm);
     99 
    100     AddOnPair<A1, A2> *a = new AddOnPair<A1, A2>(a1, a2);
    101     if (a1)
    102       a1->DecrRefCount();
    103     if (a2)
    104       a2->DecrRefCount();
    105     return a;
    106   };
    107 
    108   bool Write(ostream &ostrm) const {
    109     bool have_addon1 = a1_;
    110     WriteType(ostrm, have_addon1);
    111     if (have_addon1)
    112       a1_->Write(ostrm);
    113     bool have_addon2 = a2_;
    114     WriteType(ostrm, have_addon2);
    115     if (have_addon2)
    116       a2_->Write(ostrm);
    117     return true;
    118   }
    119 
    120   int RefCount() const { return ref_count_.count(); }
    121 
    122   int IncrRefCount() {
    123     return ref_count_.Incr();
    124   }
    125 
    126   int DecrRefCount() {
    127     return ref_count_.Decr();
    128   }
    129 
    130  private:
    131   A1 *a1_;
    132   A2 *a2_;
    133   RefCounter ref_count_;
    134 
    135   DISALLOW_COPY_AND_ASSIGN(AddOnPair);
    136 };
    137 
    138 
    139 // Add to an Fst F a type T object. T must have a 'T* Read(istream &)',
    140 // a 'bool Write(ostream &)' method, and 'int RecCount(), 'int IncrRefCount()'
    141 // and 'int DecrRefCount()' methods (e.g. 'MatcherData' in matcher-fst.h).
    142 // The result is a new Fst implemenation with type name 'type'.
    143 template<class F, class T>
    144 class AddOnImpl : public FstImpl<typename F::Arc> {
    145  public:
    146   typedef typename F::Arc Arc;
    147   typedef typename Arc::Label Label;
    148   typedef typename Arc::Weight Weight;
    149   typedef typename Arc::StateId StateId;
    150 
    151   using FstImpl<Arc>::SetType;
    152   using FstImpl<Arc>::SetProperties;
    153   using FstImpl<Arc>::WriteHeader;
    154 
    155   // If 't' is non-zero, its reference count is incremented.
    156   AddOnImpl(const F &fst, const string &type, T *t = 0)
    157       : fst_(fst), t_(t) {
    158     SetType(type);
    159     SetProperties(fst_.Properties(kFstProperties, false));
    160     if (t_)
    161       t_->IncrRefCount();
    162   }
    163 
    164   // If 't' is non-zero, its reference count is incremented.
    165   AddOnImpl(const Fst<Arc> &fst, const string &type, T *t = 0)
    166       : fst_(fst), t_(t) {
    167     SetType(type);
    168     SetProperties(fst_.Properties(kFstProperties, false));
    169     if (t_)
    170       t_->IncrRefCount();
    171   }
    172 
    173   AddOnImpl(const AddOnImpl<F, T> &impl)
    174       : fst_(impl.fst_), t_(impl.t_) {
    175     SetType(impl.Type());
    176     SetProperties(fst_.Properties(kCopyProperties, false));
    177     if (t_)
    178       t_->IncrRefCount();
    179   }
    180 
    181   ~AddOnImpl() {
    182     if (t_ && !t_->DecrRefCount())
    183       delete t_;
    184   }
    185 
    186   StateId Start() const { return fst_.Start(); }
    187   Weight Final(StateId s) const { return fst_.Final(s); }
    188   size_t NumArcs(StateId s) const { return fst_.NumArcs(s); }
    189 
    190   size_t NumInputEpsilons(StateId s) const {
    191     return fst_.NumInputEpsilons(s);
    192   }
    193 
    194   size_t NumOutputEpsilons(StateId s) const {
    195     return fst_.NumOutputEpsilons(s);
    196   }
    197 
    198   size_t NumStates() const { return fst_.NumStates(); }
    199 
    200   static AddOnImpl<F, T> *Read(istream &strm, const FstReadOptions &opts) {
    201     FstReadOptions nopts(opts);
    202     FstHeader hdr;
    203     if (!nopts.header) {
    204       hdr.Read(strm, nopts.source);
    205       nopts.header = &hdr;
    206     }
    207     AddOnImpl<F, T> *impl = new AddOnImpl<F, T>(nopts.header->FstType());
    208     if (!impl->ReadHeader(strm, nopts, kMinFileVersion, &hdr))
    209       return 0;
    210     delete impl;       // Used here only for checking types.
    211 
    212     int32 magic_number = 0;
    213     ReadType(strm, &magic_number);   // Ensures this is an add-on Fst.
    214     if (magic_number != kAddOnMagicNumber) {
    215       LOG(ERROR) << "AddOnImpl::Read: Bad add-on header: " << nopts.source;
    216       return 0;
    217     }
    218 
    219     FstReadOptions fopts(opts);
    220     fopts.header = 0;  // Contained header was written out.
    221     F *fst = F::Read(strm, fopts);
    222     if (!fst)
    223       return 0;
    224 
    225     T *t = 0;
    226     bool have_addon = false;
    227     ReadType(strm, &have_addon);
    228     if (have_addon) {   // Read add-on object if present.
    229       t = T::Read(strm);
    230       if (!t)
    231         return 0;
    232     }
    233     impl = new AddOnImpl<F, T>(*fst, nopts.header->FstType(), t);
    234     delete fst;
    235     if (t)
    236       t->DecrRefCount();
    237     return impl;
    238   }
    239 
    240   bool Write(ostream &strm, const FstWriteOptions &opts) const {
    241     FstHeader hdr;
    242     FstWriteOptions nopts(opts);
    243     nopts.write_isymbols = false;  // Let contained FST hold any symbols.
    244     nopts.write_osymbols = false;
    245     WriteHeader(strm, nopts, kFileVersion, &hdr);
    246     WriteType(strm, kAddOnMagicNumber);  // Ensures this is an add-on Fst.
    247     FstWriteOptions fopts(opts);
    248     fopts.write_header = true;     // Force writing contained header.
    249     if (!fst_.Write(strm, fopts))
    250       return false;
    251     bool have_addon = t_;
    252     WriteType(strm, have_addon);
    253     if (have_addon)                // Write add-on object if present.
    254       t_->Write(strm);
    255     return true;
    256   }
    257 
    258   void InitStateIterator(StateIteratorData<Arc> *data) const {
    259     fst_.InitStateIterator(data);
    260   }
    261 
    262   void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
    263     fst_.InitArcIterator(s, data);
    264   }
    265 
    266   F &GetFst() { return fst_; }
    267 
    268   const F &GetFst() const { return fst_; }
    269 
    270   T *GetAddOn() const { return t_; }
    271 
    272   // If 't' is non-zero, its reference count is incremented.
    273   void SetAddOn(T *t) {
    274     if (t == t_)
    275       return;
    276     if (t_ && !t_->DecrRefCount())
    277       delete t_;
    278     t_ = t;
    279     if (t_)
    280       t_->IncrRefCount();
    281   }
    282 
    283  private:
    284   explicit AddOnImpl(const string &type) : t_(0) {
    285     SetType(type);
    286     SetProperties(kExpanded);
    287   }
    288 
    289   // Current file format version
    290   static const int kFileVersion = 1;
    291   // Minimum file format version supported
    292   static const int kMinFileVersion = 1;
    293 
    294   F fst_;
    295   T *t_;
    296 
    297   void operator=(const AddOnImpl<F, T> &fst);  // Disallow
    298 };
    299 
    300 template <class F, class T> const int AddOnImpl<F, T>::kFileVersion;
    301 template <class F, class T> const int AddOnImpl<F, T>::kMinFileVersion;
    302 
    303 
    304 }  // namespace fst
    305 
    306 #endif  // FST_LIB_ADD_ON_FST_H__
    307