1 2 // Licensed under the Apache License, Version 2.0 (the "License"); 3 // you may not use this file except in compliance with the License. 4 // You may obtain a copy of the License at 5 // 6 // http://www.apache.org/licenses/LICENSE-2.0 7 // 8 // Unless required by applicable law or agreed to in writing, software 9 // distributed under the License is distributed on an "AS IS" BASIS, 10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 // 14 // Copyright 2005-2010 Google, Inc. 15 // Author: jpr (at) google.com (Jake Ratkiewicz) 16 17 #ifndef FST_SCRIPT_FST_CLASS_H_ 18 #define FST_SCRIPT_FST_CLASS_H_ 19 20 #include <string> 21 22 #include <fst/fst.h> 23 #include <fst/mutable-fst.h> 24 #include <fst/vector-fst.h> 25 #include <iostream> 26 #include <fstream> 27 #include <sstream> 28 29 // Classes to support "boxing" all existing types of FST arcs in a single 30 // FstClass which hides the arc types. This allows clients to load 31 // and work with FSTs without knowing the arc type. 32 33 // These classes are only recommended for use in high-level scripting 34 // applications. Most users should use the lower-level templated versions 35 // corresponding to these classes. 36 37 namespace fst { 38 namespace script { 39 40 // 41 // Abstract base class defining the set of functionalities implemented 42 // in all impls, and passed through by all bases Below FstClassBase 43 // the class hierarchy bifurcates; FstClassImplBase serves as the base 44 // class for all implementations (of which FstClassImpl is currently 45 // the only one) and FstClass serves as the base class for all 46 // interfaces. 47 // 48 class FstClassBase { 49 public: 50 virtual const string &ArcType() const = 0; 51 virtual const string &FstType() const = 0; 52 virtual const string &WeightType() const = 0; 53 virtual const SymbolTable *InputSymbols() const = 0; 54 virtual const SymbolTable *OutputSymbols() const = 0; 55 virtual bool Write(const string& fname) const = 0; 56 virtual bool Write(ostream &ostr, const FstWriteOptions &opts) const = 0; 57 virtual uint64 Properties(uint64 mask, bool test) const = 0; 58 virtual ~FstClassBase() { } 59 }; 60 61 class FstClassImplBase : public FstClassBase { 62 public: 63 virtual FstClassImplBase *Copy() = 0; 64 virtual void SetInputSymbols(SymbolTable *is) = 0; 65 virtual void SetOutputSymbols(SymbolTable *is) = 0; 66 virtual ~FstClassImplBase() { } 67 }; 68 69 70 // 71 // CONTAINER CLASS 72 // Wraps an Fst<Arc>, hiding its arc type. Whether this Fst<Arc> 73 // pointer refers to a special kind of FST (e.g. a MutableFst) is 74 // known by the type of interface class that owns the pointer to this 75 // container. 76 // 77 78 template<class Arc> 79 class FstClassImpl : public FstClassImplBase { 80 public: 81 explicit FstClassImpl(Fst<Arc> *impl, 82 bool should_own = false) : 83 impl_(should_own ? impl : impl->Copy()) { } 84 85 explicit FstClassImpl(const Fst<Arc> &impl) : impl_(impl.Copy()) { } 86 87 virtual const string &ArcType() const { 88 return Arc::Type(); 89 } 90 91 virtual const string &FstType() const { 92 return impl_->Type(); 93 } 94 95 virtual const string &WeightType() const { 96 return Arc::Weight::Type(); 97 } 98 99 virtual const SymbolTable *InputSymbols() const { 100 return impl_->InputSymbols(); 101 } 102 103 virtual const SymbolTable *OutputSymbols() const { 104 return impl_->OutputSymbols(); 105 } 106 107 // Warning: calling this method casts the FST to a mutable FST. 108 virtual void SetInputSymbols(SymbolTable *is) { 109 static_cast<MutableFst<Arc> *>(impl_)->SetInputSymbols(is); 110 } 111 112 // Warning: calling this method casts the FST to a mutable FST. 113 virtual void SetOutputSymbols(SymbolTable *os) { 114 static_cast<MutableFst<Arc> *>(impl_)->SetOutputSymbols(os); 115 } 116 117 virtual bool Write(const string &fname) const { 118 return impl_->Write(fname); 119 } 120 121 virtual bool Write(ostream &ostr, const FstWriteOptions &opts) const { 122 return impl_->Write(ostr, opts); 123 } 124 125 virtual uint64 Properties(uint64 mask, bool test) const { 126 return impl_->Properties(mask, test); 127 } 128 129 virtual ~FstClassImpl() { delete impl_; } 130 131 Fst<Arc> *GetImpl() const { return impl_; } 132 133 Fst<Arc> *GetImpl() { return impl_; } 134 135 virtual FstClassImpl *Copy() { 136 return new FstClassImpl<Arc>(impl_); 137 } 138 139 private: 140 Fst<Arc> *impl_; 141 }; 142 143 // 144 // BASE CLASS DEFINITIONS 145 // 146 147 class MutableFstClass; 148 149 class FstClass : public FstClassBase { 150 public: 151 template<class Arc> 152 static FstClass *Read(istream &stream, 153 const FstReadOptions &opts) { 154 if (!opts.header) { 155 FSTERROR() << "FstClass::Read: options header not specified"; 156 return 0; 157 } 158 const FstHeader &hdr = *opts.header; 159 160 if (hdr.Properties() & kMutable) { 161 return ReadTypedFst<MutableFstClass, MutableFst<Arc> >(stream, opts); 162 } else { 163 return ReadTypedFst<FstClass, Fst<Arc> >(stream, opts); 164 } 165 } 166 167 FstClass() : impl_(NULL) { 168 } 169 170 template<class Arc> 171 explicit FstClass(const Fst<Arc> &fst) : impl_(new FstClassImpl<Arc>(fst)) { 172 } 173 174 FstClass(const FstClass &other) : impl_(other.impl_->Copy()) { } 175 176 FstClass &operator=(const FstClass &other) { 177 delete impl_; 178 impl_ = other.impl_->Copy(); 179 return *this; 180 } 181 182 static FstClass *Read(const string &fname); 183 184 static FstClass *Read(istream &istr, const string &source); 185 186 virtual const string &ArcType() const { 187 return impl_->ArcType(); 188 } 189 190 virtual const string& FstType() const { 191 return impl_->FstType(); 192 } 193 194 virtual const SymbolTable *InputSymbols() const { 195 return impl_->InputSymbols(); 196 } 197 198 virtual const SymbolTable *OutputSymbols() const { 199 return impl_->OutputSymbols(); 200 } 201 202 virtual const string& WeightType() const { 203 return impl_->WeightType(); 204 } 205 206 virtual bool Write(const string &fname) const { 207 return impl_->Write(fname); 208 } 209 210 virtual bool Write(ostream &ostr, const FstWriteOptions &opts) const { 211 return impl_->Write(ostr, opts); 212 } 213 214 virtual uint64 Properties(uint64 mask, bool test) const { 215 return impl_->Properties(mask, test); 216 } 217 218 template<class Arc> 219 const Fst<Arc> *GetFst() const { 220 if (Arc::Type() != ArcType()) { 221 return NULL; 222 } else { 223 FstClassImpl<Arc> *typed_impl = static_cast<FstClassImpl<Arc> *>(impl_); 224 return typed_impl->GetImpl(); 225 } 226 } 227 228 virtual ~FstClass() { delete impl_; } 229 230 // These methods are required by IO registration 231 template<class Arc> 232 static FstClassImplBase *Convert(const FstClass &other) { 233 LOG(ERROR) << "Doesn't make sense to convert any class to type FstClass."; 234 return 0; 235 } 236 237 template<class Arc> 238 static FstClassImplBase *Create() { 239 LOG(ERROR) << "Doesn't make sense to create an FstClass with a " 240 << "particular arc type."; 241 return 0; 242 } 243 244 245 protected: 246 explicit FstClass(FstClassImplBase *impl) : impl_(impl) { } 247 248 // Generic template method for reading an arc-templated FST of type 249 // UnderlyingT, and returning it wrapped as FstClassT, with appropriate 250 // error checking. Called from arc-templated Read() static methods. 251 template<class FstClassT, class UnderlyingT> 252 static FstClassT* ReadTypedFst(istream &stream, 253 const FstReadOptions &opts) { 254 UnderlyingT *u = UnderlyingT::Read(stream, opts); 255 if (!u) { 256 return 0; 257 } else { 258 FstClassT *r = new FstClassT(*u); 259 delete u; 260 return r; 261 } 262 } 263 264 FstClassImplBase *GetImpl() const { return impl_; } 265 266 FstClassImplBase *GetImpl() { return impl_; } 267 268 // friend ostream &operator<<(ostream&, const FstClass&); 269 270 private: 271 FstClassImplBase *impl_; 272 }; 273 274 // 275 // Specific types of FstClass with special properties 276 // 277 278 class MutableFstClass : public FstClass { 279 public: 280 template<class Arc> 281 explicit MutableFstClass(const MutableFst<Arc> &fst) : 282 FstClass(fst) { } 283 284 template<class Arc> 285 MutableFst<Arc> *GetMutableFst() { 286 Fst<Arc> *fst = const_cast<Fst<Arc> *>(this->GetFst<Arc>()); 287 MutableFst<Arc> *mfst = static_cast<MutableFst<Arc> *>(fst); 288 289 return mfst; 290 } 291 292 template<class Arc> 293 static MutableFstClass *Read(istream &stream, 294 const FstReadOptions &opts) { 295 MutableFst<Arc> *mfst = MutableFst<Arc>::Read(stream, opts); 296 if (!mfst) { 297 return 0; 298 } else { 299 MutableFstClass *retval = new MutableFstClass(*mfst); 300 delete mfst; 301 return retval; 302 } 303 } 304 305 virtual bool Write(const string &fname) const { 306 return GetImpl()->Write(fname); 307 } 308 309 virtual bool Write(ostream &ostr, const FstWriteOptions &opts) const { 310 return GetImpl()->Write(ostr, opts); 311 } 312 313 static MutableFstClass *Read(const string &fname, bool convert = false); 314 315 virtual void SetInputSymbols(SymbolTable *is) { 316 GetImpl()->SetInputSymbols(is); 317 } 318 319 virtual void SetOutputSymbols(SymbolTable *os) { 320 GetImpl()->SetOutputSymbols(os); 321 } 322 323 // These methods are required by IO registration 324 template<class Arc> 325 static FstClassImplBase *Convert(const FstClass &other) { 326 LOG(ERROR) << "Doesn't make sense to convert any class to type " 327 << "MutableFstClass."; 328 return 0; 329 } 330 331 template<class Arc> 332 static FstClassImplBase *Create() { 333 LOG(ERROR) << "Doesn't make sense to create a MutableFstClass with a " 334 << "particular arc type."; 335 return 0; 336 } 337 338 protected: 339 explicit MutableFstClass(FstClassImplBase *impl) : FstClass(impl) { } 340 }; 341 342 343 class VectorFstClass : public MutableFstClass { 344 public: 345 explicit VectorFstClass(const FstClass &other); 346 explicit VectorFstClass(const string &arc_type); 347 348 template<class Arc> 349 explicit VectorFstClass(const VectorFst<Arc> &fst) : 350 MutableFstClass(fst) { } 351 352 template<class Arc> 353 static VectorFstClass *Read(istream &stream, 354 const FstReadOptions &opts) { 355 VectorFst<Arc> *vfst = VectorFst<Arc>::Read(stream, opts); 356 if (!vfst) { 357 return 0; 358 } else { 359 VectorFstClass *retval = new VectorFstClass(*vfst); 360 delete vfst; 361 return retval; 362 } 363 } 364 365 static VectorFstClass *Read(const string &fname); 366 367 // Converter / creator for known arc types 368 template<class Arc> 369 static FstClassImplBase *Convert(const FstClass &other) { 370 return new FstClassImpl<Arc>(new VectorFst<Arc>( 371 *other.GetFst<Arc>()), true); 372 } 373 374 template<class Arc> 375 static FstClassImplBase *Create() { 376 return new FstClassImpl<Arc>(new VectorFst<Arc>(), true); 377 } 378 }; 379 380 } // namespace script 381 } // namespace fst 382 #endif // FST_SCRIPT_FST_CLASS_H_ 383