Home | History | Annotate | Download | only in fst
      1 
      2 // Licensed under the Apache License, Version 2.0 (the "License");
      3 // you may not use this file except in compliance with the License.
      4 // You may obtain a copy of the License at
      5 //
      6 //     http://www.apache.org/licenses/LICENSE-2.0
      7 //
      8 // Unless required by applicable law or agreed to in writing, software
      9 // distributed under the License is distributed on an "AS IS" BASIS,
     10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     11 // See the License for the specific language governing permissions and
     12 // limitations under the License.
     13 //
     14 // Copyright 2005-2010 Google, Inc.
     15 // All Rights Reserved.
     16 //
     17 // Author : Johan Schalkwyk
     18 //
     19 // \file
     20 // Classes to provide symbol-to-integer and integer-to-symbol mappings.
     21 
     22 #ifndef FST_LIB_SYMBOL_TABLE_H__
     23 #define FST_LIB_SYMBOL_TABLE_H__
     24 
     25 #include <cstring>
     26 #include <string>
     27 #include <utility>
     28 using std::pair; using std::make_pair;
     29 #include <vector>
     30 using std::vector;
     31 
     32 
     33 #include <fst/compat.h>
     34 #include <iostream>
     35 #include <fstream>
     36 #include <sstream>
     37 
     38 
     39 #include <map>
     40 
     41 DECLARE_bool(fst_compat_symbols);
     42 
     43 namespace fst {
     44 
     45 // WARNING: Reading via symbol table read options should
     46 //          not be used. This is a temporary work around for
     47 //          reading symbol ranges of previously stored symbol sets.
     48 struct SymbolTableReadOptions {
     49   SymbolTableReadOptions() { }
     50 
     51   SymbolTableReadOptions(vector<pair<int64, int64> > string_hash_ranges_,
     52                          const string& source_)
     53       : string_hash_ranges(string_hash_ranges_),
     54         source(source_) { }
     55 
     56   vector<pair<int64, int64> > string_hash_ranges;
     57   string source;
     58 };
     59 
     60 struct SymbolTableTextOptions {
          // Options accepted by the text (ascii) ReadText()/WriteText() entry
          // points of the symbol table classes declared in this header.
          //
          // The constructor is only declared here; the defaults it assigns to
          // the fields are set in its out-of-line definition.
     61   SymbolTableTextOptions();
     62 
          // NOTE(review): presumably controls whether negative keys are
          // accepted in the text format -- confirm against the ReadText
          // implementation.
     63   bool allow_negative;
          // NOTE(review): presumably the set of characters separating the
          // symbol and key fields on a text line -- confirm against the
          // ReadText/WriteText implementation.
     64   string fst_field_separator;
     65 };
     66 
     67 class SymbolTableImpl {
     68  public:
     69   SymbolTableImpl(const string &name)
     70       : name_(name),
     71         available_key_(0),
     72         dense_key_limit_(0),
     73         check_sum_finalized_(false) {}
     74 
     75   explicit SymbolTableImpl(const SymbolTableImpl& impl)
     76       : name_(impl.name_),
     77         available_key_(0),
     78         dense_key_limit_(0),
     79         check_sum_finalized_(false) {
     80     for (size_t i = 0; i < impl.symbols_.size(); ++i) {
     81       AddSymbol(impl.symbols_[i], impl.Find(impl.symbols_[i]));
     82     }
     83   }
     84 
     85   ~SymbolTableImpl() {
     86     for (size_t i = 0; i < symbols_.size(); ++i)
     87       delete[] symbols_[i];
     88   }
     89 
     90   // TODO(johans): Add flag to specify whether the symbol
     91   //               should be indexed as string or int or both.
     92   int64 AddSymbol(const string& symbol, int64 key);
     93 
     94   int64 AddSymbol(const string& symbol) {
     95     int64 key = Find(symbol);
     96     return (key == -1) ? AddSymbol(symbol, available_key_++) : key;
     97   }
     98 
     99   static SymbolTableImpl* ReadText(
    100       istream &strm, const string &name,
    101       const SymbolTableTextOptions &opts = SymbolTableTextOptions());
    102 
    103   static SymbolTableImpl* Read(istream &strm,
    104                                const SymbolTableReadOptions& opts);
    105 
    106   bool Write(ostream &strm) const;
    107 
    108   //
    109   // Return the string associated with the key. If the key is out of
    110   // range (<0, >max), return an empty string.
    111   string Find(int64 key) const {
    112     if (key >=0 && key < dense_key_limit_)
    113       return string(symbols_[key]);
    114 
    115     map<int64, const char*>::const_iterator it =
    116         key_map_.find(key);
    117     if (it == key_map_.end()) {
    118       return "";
    119     }
    120     return string(it->second);
    121   }
    122 
    123   //
    124   // Return the key associated with the symbol. If the symbol
    125   // does not exists, return SymbolTable::kNoSymbol.
    126   int64 Find(const string& symbol) const {
    127     return Find(symbol.c_str());
    128   }
    129 
    130   //
    131   // Return the key associated with the symbol. If the symbol
    132   // does not exists, return SymbolTable::kNoSymbol.
    133   int64 Find(const char* symbol) const {
    134     map<const char *, int64, StrCmp>::const_iterator it =
    135         symbol_map_.find(symbol);
    136     if (it == symbol_map_.end()) {
    137       return -1;
    138     }
    139     return it->second;
    140   }
    141 
    142   int64 GetNthKey(ssize_t pos) const {
    143     if ((pos < 0) || (pos >= symbols_.size())) return -1;
    144     else return Find(symbols_[pos]);
    145   }
    146 
    147   const string& Name() const { return name_; }
    148 
    149   int IncrRefCount() const {
    150     return ref_count_.Incr();
    151   }
    152   int DecrRefCount() const {
    153     return ref_count_.Decr();
    154   }
    155   int RefCount() const {
    156     return ref_count_.count();
    157   }
    158 
    159   string CheckSum() const {
    160     MaybeRecomputeCheckSum();
    161     return check_sum_string_;
    162   }
    163 
    164   string LabeledCheckSum() const {
    165     MaybeRecomputeCheckSum();
    166     return labeled_check_sum_string_;
    167   }
    168 
    169   int64 AvailableKey() const {
    170     return available_key_;
    171   }
    172 
    173   size_t NumSymbols() const {
    174     return symbols_.size();
    175   }
    176 
    177  private:
    178   // Recomputes the checksums (both of them) if we've had changes since the last
    179   // computation (i.e., if check_sum_finalized_ is false).
    180   // Takes ~2.5 microseconds (dbg) or ~230 nanoseconds (opt) on a 2.67GHz Xeon
    181   // if the checksum is up-to-date (requiring no recomputation).
    182   void MaybeRecomputeCheckSum() const;
    183 
    184   struct StrCmp {
    185     bool operator()(const char *s1, const char *s2) const {
    186       return strcmp(s1, s2) < 0;
    187     }
    188   };
    189 
    190   string name_;
    191   int64 available_key_;
    192   int64 dense_key_limit_;
    193   vector<const char *> symbols_;
    194   map<int64, const char*> key_map_;
    195   map<const char *, int64, StrCmp> symbol_map_;
    196 
    197   mutable RefCounter ref_count_;
    198   mutable bool check_sum_finalized_;
    199   mutable string check_sum_string_;
    200   mutable string labeled_check_sum_string_;
    201   mutable Mutex check_sum_mutex_;
    202 };
    203 
    204 //
    205 // \class SymbolTable
    206 // \brief Symbol (string) to int and reverse mapping
    207 //
    208 // The SymbolTable implements the mappings of labels to strings and reverse.
    209 // SymbolTables are used to describe the alphabet of the input and output
    210 // labels for arcs in a Finite State Transducer.
    211 //
    212 // SymbolTables are reference counted and can therefore be shared across
    213 // multiple machines. For example a language model grammar G, with a
    214 // SymbolTable for the words in the language model can share this symbol
    215 // table with the lexical representation L o G.
    216 //
    217 class SymbolTable {
    218  public:
    219   static const int64 kNoSymbol = -1;
    220 
    221   // Construct symbol table with an unspecified name.
    222   SymbolTable() : impl_(new SymbolTableImpl("<unspecified>")) {}
    223 
    224   // Construct symbol table with a unique name.
    225   SymbolTable(const string& name) : impl_(new SymbolTableImpl(name)) {}
    226 
    227   // Create a reference counted copy.
    228   SymbolTable(const SymbolTable& table) : impl_(table.impl_) {
    229     impl_->IncrRefCount();
    230   }
    231 
    232   // Derefence implentation object. When reference count hits 0, delete
    233   // implementation.
    234   virtual ~SymbolTable() {
    235     if (!impl_->DecrRefCount()) delete impl_;
    236   }
    237 
    238   // Copys the implemenation from one symbol table to another.
    239   void operator=(const SymbolTable &st) {
    240     if (impl_ != st.impl_) {
    241       st.impl_->IncrRefCount();
    242       if (!impl_->DecrRefCount()) delete impl_;
    243       impl_ = st.impl_;
    244     }
    245   }
    246 
    247   // Read an ascii representation of the symbol table from an istream. Pass a
    248   // name to give the resulting SymbolTable.
    249   static SymbolTable* ReadText(
    250       istream &strm, const string& name,
    251       const SymbolTableTextOptions &opts = SymbolTableTextOptions()) {
    252     SymbolTableImpl* impl = SymbolTableImpl::ReadText(strm, name, opts);
    253     if (!impl)
    254       return 0;
    255     else
    256       return new SymbolTable(impl);
    257   }
    258 
    259   // read an ascii representation of the symbol table
    260   static SymbolTable* ReadText(const string& filename,
    261       const SymbolTableTextOptions &opts = SymbolTableTextOptions()) {
    262     ifstream strm(filename.c_str(), ifstream::in);
    263     if (!strm) {
    264       LOG(ERROR) << "SymbolTable::ReadText: Can't open file " << filename;
    265       return 0;
    266     }
    267     return ReadText(strm, filename, opts);
    268   }
    269 
    270 
    271   // WARNING: Reading via symbol table read options should
    272   //          not be used. This is a temporary work around.
    273   static SymbolTable* Read(istream &strm,
    274                            const SymbolTableReadOptions& opts) {
    275     SymbolTableImpl* impl = SymbolTableImpl::Read(strm, opts);
    276     if (!impl)
    277       return 0;
    278     else
    279       return new SymbolTable(impl);
    280   }
    281 
    282   // read a binary dump of the symbol table from a stream
    283   static SymbolTable* Read(istream &strm, const string& source) {
    284     SymbolTableReadOptions opts;
    285     opts.source = source;
    286     return Read(strm, opts);
    287   }
    288 
    289   // read a binary dump of the symbol table
    290   static SymbolTable* Read(const string& filename) {
    291     ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
    292     if (!strm) {
    293       LOG(ERROR) << "SymbolTable::Read: Can't open file " << filename;
    294       return 0;
    295     }
    296     return Read(strm, filename);
    297   }
    298 
    299   //--------------------------------------------------------
    300   // Derivable Interface (final)
    301   //--------------------------------------------------------
    302   // create a reference counted copy
    303   virtual SymbolTable* Copy() const {
    304     return new SymbolTable(*this);
    305   }
    306 
    307   // Add a symbol with given key to table. A symbol table also
    308   // keeps track of the last available key (highest key value in
    309   // the symbol table).
    310   virtual int64 AddSymbol(const string& symbol, int64 key) {
    311     MutateCheck();
    312     return impl_->AddSymbol(symbol, key);
    313   }
    314 
    315   // Add a symbol to the table. The associated value key is automatically
    316   // assigned by the symbol table.
    317   virtual int64 AddSymbol(const string& symbol) {
    318     MutateCheck();
    319     return impl_->AddSymbol(symbol);
    320   }
    321 
    322   // Add another symbol table to this table. All key values will be offset
    323   // by the current available key (highest key value in the symbol table).
    324   // Note string symbols with the same key value with still have the same
    325   // key value after the symbol table has been merged, but a different
    326   // value. Adding symbol tables do not result in changes in the base table.
    327   virtual void AddTable(const SymbolTable& table);
    328 
    329   // return the name of the symbol table
    330   virtual const string& Name() const {
    331     return impl_->Name();
    332   }
    333 
    334   // Return the label-agnostic MD5 check-sum for this table.  All new symbols
    335   // added to the table will result in an updated checksum.
    336   // DEPRECATED.
    337   virtual string CheckSum() const {
    338     return impl_->CheckSum();
    339   }
    340 
    341   // Same as CheckSum(), but this returns an label-dependent version.
    342   virtual string LabeledCheckSum() const {
    343     return impl_->LabeledCheckSum();
    344   }
    345 
    346   virtual bool Write(ostream &strm) const {
    347     return impl_->Write(strm);
    348   }
    349 
    350   bool Write(const string& filename) const {
    351     ofstream strm(filename.c_str(), ofstream::out | ofstream::binary);
    352     if (!strm) {
    353       LOG(ERROR) << "SymbolTable::Write: Can't open file " << filename;
    354       return false;
    355     }
    356     return Write(strm);
    357   }
    358 
    359   // Dump an ascii text representation of the symbol table via a stream
    360   virtual bool WriteText(
    361       ostream &strm,
    362       const SymbolTableTextOptions &opts = SymbolTableTextOptions()) const;
    363 
    364   // Dump an ascii text representation of the symbol table
    365   bool WriteText(const string& filename) const {
    366     ofstream strm(filename.c_str());
    367     if (!strm) {
    368       LOG(ERROR) << "SymbolTable::WriteText: Can't open file " << filename;
    369       return false;
    370     }
    371     return WriteText(strm);
    372   }
    373 
    374   // Return the string associated with the key. If the key is out of
    375   // range (<0, >max), log error and return an empty string.
    376   virtual string Find(int64 key) const {
    377     return impl_->Find(key);
    378   }
    379 
    380   // Return the key associated with the symbol. If the symbol
    381   // does not exists, log error and  return SymbolTable::kNoSymbol
    382   virtual int64 Find(const string& symbol) const {
    383     return impl_->Find(symbol);
    384   }
    385 
    386   // Return the key associated with the symbol. If the symbol
    387   // does not exists, log error and  return SymbolTable::kNoSymbol
    388   virtual int64 Find(const char* symbol) const {
    389     return impl_->Find(symbol);
    390   }
    391 
    392   // Return the current available key (i.e highest key number+1) in
    393   // the symbol table
    394   virtual int64 AvailableKey(void) const {
    395     return impl_->AvailableKey();
    396   }
    397 
    398   // Return the current number of symbols in table (not necessarily
    399   // equal to AvailableKey())
    400   virtual size_t NumSymbols(void) const {
    401     return impl_->NumSymbols();
    402   }
    403 
    404   virtual int64 GetNthKey(ssize_t pos) const {
    405     return impl_->GetNthKey(pos);
    406   }
    407 
    408  private:
    409   explicit SymbolTable(SymbolTableImpl* impl) : impl_(impl) {}
    410 
    411   void MutateCheck() {
    412     // Copy on write
    413     if (impl_->RefCount() > 1) {
    414       impl_->DecrRefCount();
    415       impl_ = new SymbolTableImpl(*impl_);
    416     }
    417   }
    418 
    419   const SymbolTableImpl* Impl() const {
    420     return impl_;
    421   }
    422 
    423  private:
    424   SymbolTableImpl* impl_;
    425 };
    426 
    427 
    428 //
    429 // \class SymbolTableIterator
    430 // \brief Iterator class for symbols in a symbol table
    431 class SymbolTableIterator {
    432  public:
    433   SymbolTableIterator(const SymbolTable& table)
    434       : table_(table),
    435         pos_(0),
    436         nsymbols_(table.NumSymbols()),
    437         key_(table.GetNthKey(0)) { }
    438 
    439   ~SymbolTableIterator() { }
    440 
    441   // is iterator done
    442   bool Done(void) {
    443     return (pos_ == nsymbols_);
    444   }
    445 
    446   // return the Value() of the current symbol (int64 key)
    447   int64 Value(void) {
    448     return key_;
    449   }
    450 
    451   // return the string of the current symbol
    452   string Symbol(void) {
    453     return table_.Find(key_);
    454   }
    455 
    456   // advance iterator forward
    457   void Next(void) {
    458     ++pos_;
    459     if (pos_ < nsymbols_) key_ = table_.GetNthKey(pos_);
    460   }
    461 
    462   // reset iterator
    463   void Reset(void) {
    464     pos_ = 0;
    465     key_ = table_.GetNthKey(0);
    466   }
    467 
    468  private:
    469   const SymbolTable& table_;
    470   ssize_t pos_;
    471   size_t nsymbols_;
    472   int64 key_;
    473 };
    474 
    475 
    476 // Tests compatibilty between two sets of symbol tables
    477 inline bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2,
    478                           bool warning = true) {
    479   if (!FLAGS_fst_compat_symbols) {
    480     return true;
    481   } else if (!syms1 && !syms2) {
    482     return true;
    483   } else if (syms1 && !syms2) {
    484     if (warning)
    485       LOG(WARNING) <<
    486           "CompatSymbols: first symbol table present but second missing";
    487     return false;
    488   } else if (!syms1 && syms2) {
    489     if (warning)
    490       LOG(WARNING) <<
    491           "CompatSymbols: second symbol table present but first missing";
    492     return false;
    493   } else if (syms1->LabeledCheckSum() != syms2->LabeledCheckSum()) {
    494     if (warning)
    495       LOG(WARNING) << "CompatSymbols: Symbol table check sums do not match";
    496     return false;
    497   } else {
    498     return true;
    499   }
    500 }
    501 
    502 
    503 // Relabels a symbol table as specified by the input vector of pairs
    504 // (old label, new label). The new symbol table only retains symbols
    505 // for which a relabeling is *explicitely* specified.
    506 // TODO(allauzen): consider adding options to allow for some form
    507 // of implicit identity relabeling.
    508 template <class Label>
    509 SymbolTable *RelabelSymbolTable(const SymbolTable *table,
    510                                 const vector<pair<Label, Label> > &pairs) {
    511   SymbolTable *new_table = new SymbolTable(
    512       table->Name().empty() ? string() :
    513       (string("relabeled_") + table->Name()));
    514 
    515   for (size_t i = 0; i < pairs.size(); ++i)
    516     new_table->AddSymbol(table->Find(pairs[i].first), pairs[i].second);
    517 
    518   return new_table;
    519 }
    520 
    521 // Symbol Table Serialization
    522 inline void SymbolTableToString(const SymbolTable *table, string *result) {
    523   ostringstream ostrm;
    524   table->Write(ostrm);
    525   *result = ostrm.str();
    526 }
    527 
    528 inline SymbolTable *StringToSymbolTable(const string &s) {
    529   istringstream istrm(s);
    530   return SymbolTable::Read(istrm, SymbolTableReadOptions());
    531 }
    532 
    533 
    534 
    535 }  // namespace fst
    536 
    537 #endif  // FST_LIB_SYMBOL_TABLE_H__
    538