Home | History | Annotate | Download | only in fst
      1 // mutable-fst.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // \file
     19 // Expanded FST augmented with mutators - interface class definition
     20 // and mutable arc iterator interface.
     21 //
     22 
     23 #ifndef FST_LIB_MUTABLE_FST_H__
     24 #define FST_LIB_MUTABLE_FST_H__
     25 
     26 #include <stddef.h>
     27 #include <sys/types.h>
     28 #include <string>
     29 #include <vector>
     30 using std::vector;
     31 
     32 #include <fst/expanded-fst.h>
     33 
     34 
     35 namespace fst {
     36 
     37 template <class A> class MutableArcIteratorData;
     38 
     39 // An expanded FST plus mutators (use MutableArcIterator to modify arcs).
     40 template <class A>
     41 class MutableFst : public ExpandedFst<A> {
     42  public:
     43   typedef A Arc;
     44   typedef typename A::Weight Weight;
     45   typedef typename A::StateId StateId;
     46 
     47   virtual MutableFst<A> &operator=(const Fst<A> &fst) = 0;
     48 
     49   MutableFst<A> &operator=(const MutableFst<A> &fst) {
     50     return operator=(static_cast<const Fst<A> &>(fst));
     51   }
     52 
     53   virtual void SetStart(StateId) = 0;           // Set the initial state
     54   virtual void SetFinal(StateId, Weight) = 0;   // Set a state's final weight
     55   virtual void SetProperties(uint64 props,
     56                              uint64 mask) = 0;  // Set property bits wrt mask
     57 
     58   virtual StateId AddState() = 0;               // Add a state, return its ID
     59   virtual void AddArc(StateId, const A &arc) = 0;   // Add an arc to state
     60 
     61   virtual void DeleteStates(const vector<StateId>&) = 0;  // Delete some states
     62   virtual void DeleteStates() = 0;              // Delete all states
     63   virtual void DeleteArcs(StateId, size_t n) = 0;  // Delete some arcs at state
     64   virtual void DeleteArcs(StateId) = 0;         // Delete all arcs at state
     65 
     66   virtual void ReserveStates(StateId n) { }  // Optional, best effort only.
     67   virtual void ReserveArcs(StateId s, size_t n) { }  // Optional, Best effort.
     68 
     69   // Return input label symbol table; return NULL if not specified
     70   virtual const SymbolTable* InputSymbols() const = 0;
     71   // Return output label symbol table; return NULL if not specified
     72   virtual const SymbolTable* OutputSymbols() const = 0;
     73 
     74   // Return input label symbol table; return NULL if not specified
     75   virtual SymbolTable* MutableInputSymbols() = 0;
     76   // Return output label symbol table; return NULL if not specified
     77   virtual SymbolTable* MutableOutputSymbols() = 0;
     78 
     79   // Set input label symbol table; NULL signifies not unspecified
     80   virtual void SetInputSymbols(const SymbolTable* isyms) = 0;
     81   // Set output label symbol table; NULL signifies not unspecified
     82   virtual void SetOutputSymbols(const SymbolTable* osyms) = 0;
     83 
     84   // Get a copy of this MutableFst. See Fst<>::Copy() for further doc.
     85   virtual MutableFst<A> *Copy(bool safe = false) const = 0;
     86 
     87   // Read an MutableFst from an input stream; return NULL on error.
     88   static MutableFst<A> *Read(istream &strm, const FstReadOptions &opts) {
     89     FstReadOptions ropts(opts);
     90     FstHeader hdr;
     91     if (ropts.header)
     92       hdr = *opts.header;
     93     else {
     94       if (!hdr.Read(strm, opts.source))
     95         return 0;
     96       ropts.header = &hdr;
     97     }
     98     if (!(hdr.Properties() & kMutable)) {
     99       LOG(ERROR) << "MutableFst::Read: Not an MutableFst: " << ropts.source;
    100       return 0;
    101     }
    102     FstRegister<A> *registr = FstRegister<A>::GetRegister();
    103     const typename FstRegister<A>::Reader reader =
    104       registr->GetReader(hdr.FstType());
    105     if (!reader) {
    106       LOG(ERROR) << "MutableFst::Read: Unknown FST type \"" << hdr.FstType()
    107                  << "\" (arc type = \"" << A::Type()
    108                  << "\"): " << ropts.source;
    109       return 0;
    110     }
    111     Fst<A> *fst = reader(strm, ropts);
    112     if (!fst) return 0;
    113     return static_cast<MutableFst<A> *>(fst);
    114   }
    115 
    116   // Read a MutableFst from a file; return NULL on error.
    117   // Empty filename reads from standard input. If 'convert' is true,
    118   // convert to a mutable FST of type 'convert_type' if file is
    119   // a non-mutable FST.
    120   static MutableFst<A> *Read(const string &filename, bool convert = false,
    121                              const string &convert_type = "vector") {
    122     if (convert == false) {
    123       if (!filename.empty()) {
    124         ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
    125         if (!strm) {
    126           LOG(ERROR) << "MutableFst::Read: Can't open file: " << filename;
    127           return 0;
    128         }
    129         return Read(strm, FstReadOptions(filename));
    130       } else {
    131         return Read(std::cin, FstReadOptions("standard input"));
    132       }
    133     } else {  // Converts to 'convert_type' if not mutable.
    134       Fst<A> *ifst = Fst<A>::Read(filename);
    135       if (!ifst) return 0;
    136       if (ifst->Properties(kMutable, false)) {
    137         return static_cast<MutableFst *>(ifst);
    138       } else {
    139         Fst<A> *ofst = Convert(*ifst, convert_type);
    140         delete ifst;
    141         if (!ofst) return 0;
    142         if (!ofst->Properties(kMutable, false))
    143           LOG(ERROR) << "MutableFst: bad convert type: " << convert_type;
    144         return static_cast<MutableFst *>(ofst);
    145       }
    146     }
    147   }
    148 
    149   // For generic mutuble arc iterator construction; not normally called
    150   // directly by users.
    151   virtual void InitMutableArcIterator(StateId s,
    152                                       MutableArcIteratorData<A> *) = 0;
    153 };
    154 
    155 // Mutable arc iterator interface, templated on the Arc definition; used
    156 // for mutable Arc iterator specializations that are returned by
    157 // the InitMutableArcIterator MutableFst method.
    158 template <class A>
    159 class MutableArcIteratorBase : public ArcIteratorBase<A> {
    160  public:
    161   typedef A Arc;
    162 
    163   void SetValue(const A &arc) { SetValue_(arc); }  // Set current arc's content
    164 
    165  private:
    166   virtual void SetValue_(const A &arc) = 0;
    167 };
    168 
    169 template <class A>
    170 struct MutableArcIteratorData {
    171   MutableArcIteratorBase<A> *base;  // Specific iterator
    172 };
    173 
    174 // Generic mutable arc iterator, templated on the FST definition
    175 // - a wrapper around pointer to specific one.
    176 // Here is a typical use: \code
    177 //   for (MutableArcIterator<StdFst> aiter(&fst, s));
    178 //        !aiter.Done();
    179 //         aiter.Next()) {
    180 //     StdArc arc = aiter.Value();
    181 //     arc.ilabel = 7;
    182 //     aiter.SetValue(arc);
    183 //     ...
    184 //   } \endcode
    185 // This version requires function calls.
    186 template <class F>
    187 class MutableArcIterator {
    188  public:
    189   typedef F FST;
    190   typedef typename F::Arc Arc;
    191   typedef typename Arc::StateId StateId;
    192 
    193   MutableArcIterator(F *fst, StateId s) {
    194     fst->InitMutableArcIterator(s, &data_);
    195   }
    196   ~MutableArcIterator() { delete data_.base; }
    197 
    198   bool Done() const { return data_.base->Done(); }
    199   const Arc& Value() const { return data_.base->Value(); }
    200   void Next() { data_.base->Next(); }
    201   size_t Position() const { return data_.base->Position(); }
    202   void Reset() { data_.base->Reset(); }
    203   void Seek(size_t a) { data_.base->Seek(a); }
    204   void SetValue(const Arc &a) { data_.base->SetValue(a); }
    205   uint32 Flags() const { return data_.base->Flags(); }
    206   void SetFlags(uint32 f, uint32 m) {
    207     return data_.base->SetFlags(f, m);
    208   }
    209 
    210  private:
    211   MutableArcIteratorData<Arc> data_;
    212   DISALLOW_COPY_AND_ASSIGN(MutableArcIterator);
    213 };
    214 
    215 
    216 namespace internal {
    217 
    218 //  MutableFst<A> case - abstract methods.
    219 template <class A> inline
    220 typename A::Weight Final(const MutableFst<A> &fst, typename A::StateId s) {
    221   return fst.Final(s);
    222 }
    223 
    224 template <class A> inline
    225 ssize_t NumArcs(const MutableFst<A> &fst, typename A::StateId s) {
    226   return fst.NumArcs(s);
    227 }
    228 
    229 template <class A> inline
    230 ssize_t NumInputEpsilons(const MutableFst<A> &fst, typename A::StateId s) {
    231   return fst.NumInputEpsilons(s);
    232 }
    233 
    234 template <class A> inline
    235 ssize_t NumOutputEpsilons(const MutableFst<A> &fst, typename A::StateId s) {
    236   return fst.NumOutputEpsilons(s);
    237 }
    238 
    239 }  // namespace internal
    240 
    241 
    242 // A useful alias when using StdArc.
    243 typedef MutableFst<StdArc> StdMutableFst;
    244 
    245 
    246 // This is a helper class template useful for attaching a MutableFst
    247 // interface to its implementation, handling reference counting and
    248 // copy-on-write.
    249 template <class I, class F = MutableFst<typename I::Arc> >
    250 class ImplToMutableFst : public ImplToExpandedFst<I, F> {
    251  public:
    252   typedef typename I::Arc Arc;
    253   typedef typename Arc::Weight Weight;
    254   typedef typename Arc::StateId StateId;
    255 
    256   using ImplToFst<I, F>::GetImpl;
    257   using ImplToFst<I, F>::SetImpl;
    258 
    259   virtual void SetStart(StateId s) {
    260     MutateCheck();
    261     GetImpl()->SetStart(s);
    262   }
    263 
    264   virtual void SetFinal(StateId s, Weight w) {
    265     MutateCheck();
    266     GetImpl()->SetFinal(s, w);
    267   }
    268 
    269   virtual void SetProperties(uint64 props, uint64 mask) {
    270     // Can skip mutate check if extrinsic properties don't change,
    271     // since it is then safe to update all (shallow) copies
    272     uint64 exprops = kExtrinsicProperties & mask;
    273     if (GetImpl()->Properties(exprops) != (props & exprops))
    274       MutateCheck();
    275     GetImpl()->SetProperties(props, mask);
    276   }
    277 
    278   virtual StateId AddState() {
    279     MutateCheck();
    280     return GetImpl()->AddState();
    281   }
    282 
    283   virtual void AddArc(StateId s, const Arc &arc) {
    284     MutateCheck();
    285     GetImpl()->AddArc(s, arc);
    286   }
    287 
    288   virtual void DeleteStates(const vector<StateId> &dstates) {
    289     MutateCheck();
    290     GetImpl()->DeleteStates(dstates);
    291   }
    292 
    293   virtual void DeleteStates() {
    294     MutateCheck();
    295     GetImpl()->DeleteStates();
    296   }
    297 
    298   virtual void DeleteArcs(StateId s, size_t n) {
    299     MutateCheck();
    300     GetImpl()->DeleteArcs(s, n);
    301   }
    302 
    303   virtual void DeleteArcs(StateId s) {
    304     MutateCheck();
    305     GetImpl()->DeleteArcs(s);
    306   }
    307 
    308   virtual void ReserveStates(StateId s) {
    309     MutateCheck();
    310     GetImpl()->ReserveStates(s);
    311   }
    312 
    313   virtual void ReserveArcs(StateId s, size_t n) {
    314     MutateCheck();
    315     GetImpl()->ReserveArcs(s, n);
    316   }
    317 
    318   virtual const SymbolTable* InputSymbols() const {
    319     return GetImpl()->InputSymbols();
    320   }
    321 
    322   virtual const SymbolTable* OutputSymbols() const {
    323     return GetImpl()->OutputSymbols();
    324   }
    325 
    326   virtual SymbolTable* MutableInputSymbols() {
    327     MutateCheck();
    328     return GetImpl()->InputSymbols();
    329   }
    330 
    331   virtual SymbolTable* MutableOutputSymbols() {
    332     MutateCheck();
    333     return GetImpl()->OutputSymbols();
    334   }
    335 
    336   virtual void SetInputSymbols(const SymbolTable* isyms) {
    337     MutateCheck();
    338     GetImpl()->SetInputSymbols(isyms);
    339   }
    340 
    341   virtual void SetOutputSymbols(const SymbolTable* osyms) {
    342     MutateCheck();
    343     GetImpl()->SetOutputSymbols(osyms);
    344   }
    345 
    346  protected:
    347   ImplToMutableFst() : ImplToExpandedFst<I, F>() {}
    348 
    349   ImplToMutableFst(I *impl) : ImplToExpandedFst<I, F>(impl) {}
    350 
    351 
    352   ImplToMutableFst(const ImplToMutableFst<I, F> &fst)
    353       : ImplToExpandedFst<I, F>(fst) {}
    354 
    355   ImplToMutableFst(const ImplToMutableFst<I, F> &fst, bool safe)
    356       : ImplToExpandedFst<I, F>(fst, safe) {}
    357 
    358   void MutateCheck() {
    359     // Copy on write
    360     if (GetImpl()->RefCount() > 1)
    361       SetImpl(new I(*this));
    362   }
    363 
    364  private:
    365   // Disallow
    366   ImplToMutableFst<I, F>  &operator=(const ImplToMutableFst<I, F> &fst);
    367 
    368   ImplToMutableFst<I, F> &operator=(const Fst<Arc> &fst) {
    369     FSTERROR() << "ImplToMutableFst: Assignment operator disallowed";
    370     GetImpl()->SetProperties(kError, kError);
    371     return *this;
    372   }
    373 };
    374 
    375 
    376 }  // namespace fst
    377 
    378 #endif  // FST_LIB_MUTABLE_FST_H__
    379