Home | History | Annotate | Download | only in fst
      1 
      2 // Licensed under the Apache License, Version 2.0 (the "License");
      3 // you may not use this file except in compliance with the License.
      4 // You may obtain a copy of the License at
      5 //
      6 //     http://www.apache.org/licenses/LICENSE-2.0
      7 //
      8 // Unless required by applicable law or agreed to in writing, software
      9 // distributed under the License is distributed on an "AS IS" BASIS,
     10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     11 // See the License for the specific language governing permissions and
     12 // limitations under the License.
     13 //
     14 // Copyright 2005-2010 Google, Inc.
     15 // All Rights Reserved.
     16 //
     17 // Author : Johan Schalkwyk
     18 //
     19 // \file
     20 // Classes to provide symbol-to-integer and integer-to-symbol mappings.
     21 
     22 #ifndef FST_LIB_SYMBOL_TABLE_H__
     23 #define FST_LIB_SYMBOL_TABLE_H__
     24 
     25 #include <cstring>
     26 #include <string>
     27 #include <utility>
     28 using std::pair; using std::make_pair;
     29 #include <vector>
     30 using std::vector;
     31 
     32 
     33 #include <fst/compat.h>
     34 #include <iostream>
     35 #include <fstream>
     36 
     37 
     38 #include <map>
     39 
     40 DECLARE_bool(fst_compat_symbols);
     41 
     42 namespace fst {
     43 
     44 // WARNING: Reading via symbol table read options should
     45 //          not be used. This is a temporary work around for
     46 //          reading symbol ranges of previously stored symbol sets.
     47 struct SymbolTableReadOptions {
     48   SymbolTableReadOptions() { }
     49 
     50   SymbolTableReadOptions(vector<pair<int64, int64> > string_hash_ranges_,
     51                          const string& source_)
     52       : string_hash_ranges(string_hash_ranges_),
     53         source(source_) { }
     54 
     55   vector<pair<int64, int64> > string_hash_ranges;
     56   string source;
     57 };
     58 
     59 class SymbolTableImpl {
     60  public:
     61   SymbolTableImpl(const string &name)
     62       : name_(name),
     63         available_key_(0),
     64         dense_key_limit_(0),
     65         check_sum_finalized_(false) {}
     66 
     67   explicit SymbolTableImpl(const SymbolTableImpl& impl)
     68       : name_(impl.name_),
     69         available_key_(0),
     70         dense_key_limit_(0),
     71         check_sum_finalized_(false) {
     72     for (size_t i = 0; i < impl.symbols_.size(); ++i) {
     73       AddSymbol(impl.symbols_[i], impl.Find(impl.symbols_[i]));
     74     }
     75   }
     76 
     77   ~SymbolTableImpl() {
     78     for (size_t i = 0; i < symbols_.size(); ++i)
     79       delete[] symbols_[i];
     80   }
     81 
     82   // TODO(johans): Add flag to specify whether the symbol
     83   //               should be indexed as string or int or both.
     84   int64 AddSymbol(const string& symbol, int64 key);
     85 
     86   int64 AddSymbol(const string& symbol) {
     87     int64 key = Find(symbol);
     88     return (key == -1) ? AddSymbol(symbol, available_key_++) : key;
     89   }
     90 
     91   static SymbolTableImpl* ReadText(istream &strm,
     92                                    const string &name,
     93                                    bool allow_negative = false);
     94 
     95   static SymbolTableImpl* Read(istream &strm,
     96                                const SymbolTableReadOptions& opts);
     97 
     98   bool Write(ostream &strm) const;
     99 
    100   //
    101   // Return the string associated with the key. If the key is out of
    102   // range (<0, >max), return an empty string.
    103   string Find(int64 key) const {
    104     if (key >=0 && key < dense_key_limit_)
    105       return string(symbols_[key]);
    106 
    107     map<int64, const char*>::const_iterator it =
    108         key_map_.find(key);
    109     if (it == key_map_.end()) {
    110       return "";
    111     }
    112     return string(it->second);
    113   }
    114 
    115   //
    116   // Return the key associated with the symbol. If the symbol
    117   // does not exists, return SymbolTable::kNoSymbol.
    118   int64 Find(const string& symbol) const {
    119     return Find(symbol.c_str());
    120   }
    121 
    122   //
    123   // Return the key associated with the symbol. If the symbol
    124   // does not exists, return SymbolTable::kNoSymbol.
    125   int64 Find(const char* symbol) const {
    126     map<const char *, int64, StrCmp>::const_iterator it =
    127         symbol_map_.find(symbol);
    128     if (it == symbol_map_.end()) {
    129       return -1;
    130     }
    131     return it->second;
    132   }
    133 
    134   int64 GetNthKey(ssize_t pos) const {
    135     if ((pos < 0) || (pos >= symbols_.size())) return -1;
    136     else return Find(symbols_[pos]);
    137   }
    138 
    139   const string& Name() const { return name_; }
    140 
    141   int IncrRefCount() const {
    142     return ref_count_.Incr();
    143   }
    144   int DecrRefCount() const {
    145     return ref_count_.Decr();
    146   }
    147   int RefCount() const {
    148     return ref_count_.count();
    149   }
    150 
    151   string CheckSum() const {
    152     MutexLock check_sum_lock(&check_sum_mutex_);
    153     MaybeRecomputeCheckSum();
    154     return check_sum_string_;
    155   }
    156 
    157   string LabeledCheckSum() const {
    158     MutexLock check_sum_lock(&check_sum_mutex_);
    159     MaybeRecomputeCheckSum();
    160     return labeled_check_sum_string_;
    161   }
    162 
    163   int64 AvailableKey() const {
    164     return available_key_;
    165   }
    166 
    167   size_t NumSymbols() const {
    168     return symbols_.size();
    169   }
    170 
    171  private:
    172   // Recomputes the checksums (both of them) if we've had changes since the last
    173   // computation (i.e., if check_sum_finalized_ is false).
    174   void MaybeRecomputeCheckSum() const;
    175 
    176   struct StrCmp {
    177     bool operator()(const char *s1, const char *s2) const {
    178       return strcmp(s1, s2) < 0;
    179     }
    180   };
    181 
    182   string name_;
    183   int64 available_key_;
    184   int64 dense_key_limit_;
    185   vector<const char *> symbols_;
    186   map<int64, const char*> key_map_;
    187   map<const char *, int64, StrCmp> symbol_map_;
    188 
    189   mutable RefCounter ref_count_;
    190   mutable bool check_sum_finalized_;
    191   mutable CheckSummer check_sum_;
    192   mutable CheckSummer labeled_check_sum_;
    193   mutable string check_sum_string_;
    194   mutable string labeled_check_sum_string_;
    195   mutable Mutex check_sum_mutex_;
    196 };
    197 
    198 //
    199 // \class SymbolTable
    200 // \brief Symbol (string) to int and reverse mapping
    201 //
    202 // The SymbolTable implements the mappings of labels to strings and reverse.
    203 // SymbolTables are used to describe the alphabet of the input and output
    204 // labels for arcs in a Finite State Transducer.
    205 //
    206 // SymbolTables are reference counted and can therefore be shared across
    207 // multiple machines. For example a language model grammar G, with a
    208 // SymbolTable for the words in the language model can share this symbol
    209 // table with the lexical representation L o G.
    210 //
    211 class SymbolTable {
    212  public:
    213   static const int64 kNoSymbol = -1;
    214 
    215   // Construct symbol table with a unique name.
    216   SymbolTable(const string& name) : impl_(new SymbolTableImpl(name)) {}
    217 
    218   // Create a reference counted copy.
    219   SymbolTable(const SymbolTable& table) : impl_(table.impl_) {
    220     impl_->IncrRefCount();
    221   }
    222 
    223   // Derefence implentation object. When reference count hits 0, delete
    224   // implementation.
    225   virtual ~SymbolTable() {
    226     if (!impl_->DecrRefCount()) delete impl_;
    227   }
    228 
    229   // Read an ascii representation of the symbol table from an istream. Pass a
    230   // name to give the resulting SymbolTable.
    231   static SymbolTable* ReadText(istream &strm,
    232                                const string& name,
    233                                bool allow_negative = false) {
    234     SymbolTableImpl* impl = SymbolTableImpl::ReadText(strm,
    235                                                       name,
    236                                                       allow_negative);
    237     if (!impl)
    238       return 0;
    239     else
    240       return new SymbolTable(impl);
    241   }
    242 
    243   // read an ascii representation of the symbol table
    244   static SymbolTable* ReadText(const string& filename,
    245                                bool allow_negative = false) {
    246     ifstream strm(filename.c_str(), ifstream::in);
    247     if (!strm) {
    248       LOG(ERROR) << "SymbolTable::ReadText: Can't open file " << filename;
    249       return 0;
    250     }
    251     return ReadText(strm, filename, allow_negative);
    252   }
    253 
    254 
    255   // WARNING: Reading via symbol table read options should
    256   //          not be used. This is a temporary work around.
    257   static SymbolTable* Read(istream &strm,
    258                            const SymbolTableReadOptions& opts) {
    259     SymbolTableImpl* impl = SymbolTableImpl::Read(strm, opts);
    260     if (!impl)
    261       return 0;
    262     else
    263       return new SymbolTable(impl);
    264   }
    265 
    266   // read a binary dump of the symbol table from a stream
    267   static SymbolTable* Read(istream &strm, const string& source) {
    268     SymbolTableReadOptions opts;
    269     opts.source = source;
    270     return Read(strm, opts);
    271   }
    272 
    273   // read a binary dump of the symbol table
    274   static SymbolTable* Read(const string& filename) {
    275     ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
    276     if (!strm) {
    277       LOG(ERROR) << "SymbolTable::Read: Can't open file " << filename;
    278       return 0;
    279     }
    280     return Read(strm, filename);
    281   }
    282 
    283   //--------------------------------------------------------
    284   // Derivable Interface (final)
    285   //--------------------------------------------------------
    286   // create a reference counted copy
    287   virtual SymbolTable* Copy() const {
    288     return new SymbolTable(*this);
    289   }
    290 
    291   // Add a symbol with given key to table. A symbol table also
    292   // keeps track of the last available key (highest key value in
    293   // the symbol table).
    294   virtual int64 AddSymbol(const string& symbol, int64 key) {
    295     MutateCheck();
    296     return impl_->AddSymbol(symbol, key);
    297   }
    298 
    299   // Add a symbol to the table. The associated value key is automatically
    300   // assigned by the symbol table.
    301   virtual int64 AddSymbol(const string& symbol) {
    302     MutateCheck();
    303     return impl_->AddSymbol(symbol);
    304   }
    305 
    306   // Add another symbol table to this table. All key values will be offset
    307   // by the current available key (highest key value in the symbol table).
    308   // Note string symbols with the same key value with still have the same
    309   // key value after the symbol table has been merged, but a different
    310   // value. Adding symbol tables do not result in changes in the base table.
    311   virtual void AddTable(const SymbolTable& table);
    312 
    313   // return the name of the symbol table
    314   virtual const string& Name() const {
    315     return impl_->Name();
    316   }
    317 
    318   // Return the label-agnostic MD5 check-sum for this table.  All new symbols
    319   // added to the table will result in an updated checksum.
    320   // DEPRECATED.
    321   virtual string CheckSum() const {
    322     return impl_->CheckSum();
    323   }
    324 
    325   // Same as CheckSum(), but this returns an label-dependent version.
    326   virtual string LabeledCheckSum() const {
    327     return impl_->LabeledCheckSum();
    328   }
    329 
    330   virtual bool Write(ostream &strm) const {
    331     return impl_->Write(strm);
    332   }
    333 
    334   bool Write(const string& filename) const {
    335     ofstream strm(filename.c_str(), ofstream::out | ofstream::binary);
    336     if (!strm) {
    337       LOG(ERROR) << "SymbolTable::Write: Can't open file " << filename;
    338       return false;
    339     }
    340     return Write(strm);
    341   }
    342 
    343   // Dump an ascii text representation of the symbol table via a stream
    344   virtual bool WriteText(ostream &strm) const;
    345 
    346   // Dump an ascii text representation of the symbol table
    347   bool WriteText(const string& filename) const {
    348     ofstream strm(filename.c_str());
    349     if (!strm) {
    350       LOG(ERROR) << "SymbolTable::WriteText: Can't open file " << filename;
    351       return false;
    352     }
    353     return WriteText(strm);
    354   }
    355 
    356   // Return the string associated with the key. If the key is out of
    357   // range (<0, >max), log error and return an empty string.
    358   virtual string Find(int64 key) const {
    359     return impl_->Find(key);
    360   }
    361 
    362   // Return the key associated with the symbol. If the symbol
    363   // does not exists, log error and  return SymbolTable::kNoSymbol
    364   virtual int64 Find(const string& symbol) const {
    365     return impl_->Find(symbol);
    366   }
    367 
    368   // Return the key associated with the symbol. If the symbol
    369   // does not exists, log error and  return SymbolTable::kNoSymbol
    370   virtual int64 Find(const char* symbol) const {
    371     return impl_->Find(symbol);
    372   }
    373 
    374   // Return the current available key (i.e highest key number+1) in
    375   // the symbol table
    376   virtual int64 AvailableKey(void) const {
    377     return impl_->AvailableKey();
    378   }
    379 
    380   // Return the current number of symbols in table (not necessarily
    381   // equal to AvailableKey())
    382   virtual size_t NumSymbols(void) const {
    383     return impl_->NumSymbols();
    384   }
    385 
    386   virtual int64 GetNthKey(ssize_t pos) const {
    387     return impl_->GetNthKey(pos);
    388   }
    389 
    390  private:
    391   explicit SymbolTable(SymbolTableImpl* impl) : impl_(impl) {}
    392 
    393   void MutateCheck() {
    394     // Copy on write
    395     if (impl_->RefCount() > 1) {
    396       impl_->DecrRefCount();
    397       impl_ = new SymbolTableImpl(*impl_);
    398     }
    399   }
    400 
    401   const SymbolTableImpl* Impl() const {
    402     return impl_;
    403   }
    404 
    405  private:
    406   SymbolTableImpl* impl_;
    407 
    408   void operator=(const SymbolTable &table);  // disallow
    409 };
    410 
    411 
    412 //
    413 // \class SymbolTableIterator
    414 // \brief Iterator class for symbols in a symbol table
    415 class SymbolTableIterator {
    416  public:
    417   SymbolTableIterator(const SymbolTable& table)
    418       : table_(table),
    419         pos_(0),
    420         nsymbols_(table.NumSymbols()),
    421         key_(table.GetNthKey(0)) { }
    422 
    423   ~SymbolTableIterator() { }
    424 
    425   // is iterator done
    426   bool Done(void) {
    427     return (pos_ == nsymbols_);
    428   }
    429 
    430   // return the Value() of the current symbol (int64 key)
    431   int64 Value(void) {
    432     return key_;
    433   }
    434 
    435   // return the string of the current symbol
    436   string Symbol(void) {
    437     return table_.Find(key_);
    438   }
    439 
    440   // advance iterator forward
    441   void Next(void) {
    442     ++pos_;
    443     if (pos_ < nsymbols_) key_ = table_.GetNthKey(pos_);
    444   }
    445 
    446   // reset iterator
    447   void Reset(void) {
    448     pos_ = 0;
    449     key_ = table_.GetNthKey(0);
    450   }
    451 
    452  private:
    453   const SymbolTable& table_;
    454   ssize_t pos_;
    455   size_t nsymbols_;
    456   int64 key_;
    457 };
    458 
    459 
    460 // Tests compatibilty between two sets of symbol tables
    461 inline bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2,
    462                           bool warning = true) {
    463   if (!FLAGS_fst_compat_symbols) {
    464     return true;
    465   } else if (!syms1 && !syms2) {
    466     return true;
    467   } else if (syms1 && !syms2) {
    468     if (warning)
    469       LOG(WARNING) <<
    470           "CompatSymbols: first symbol table present but second missing";
    471     return false;
    472   } else if (!syms1 && syms2) {
    473     if (warning)
    474       LOG(WARNING) <<
    475           "CompatSymbols: second symbol table present but first missing";
    476     return false;
    477   } else if (syms1->LabeledCheckSum() != syms2->LabeledCheckSum()) {
    478     if (warning)
    479       LOG(WARNING) << "CompatSymbols: Symbol table check sums do not match";
    480     return false;
    481   } else {
    482     return true;
    483   }
    484 }
    485 
    486 
    487 // Relabels a symbol table as specified by the input vector of pairs
    488 // (old label, new label). The new symbol table only retains symbols
    489 // for which a relabeling is *explicitely* specified.
    490 // TODO(allauzen): consider adding options to allow for some form
    491 // of implicit identity relabeling.
    492 template <class Label>
    493 SymbolTable *RelabelSymbolTable(const SymbolTable *table,
    494                                 const vector<pair<Label, Label> > &pairs) {
    495   SymbolTable *new_table = new SymbolTable(
    496       table->Name().empty() ? string() :
    497       (string("relabeled_") + table->Name()));
    498 
    499   for (size_t i = 0; i < pairs.size(); ++i)
    500     new_table->AddSymbol(table->Find(pairs[i].first), pairs[i].second);
    501 
    502   return new_table;
    503 }
    504 
    505 }  // namespace fst
    506 
    507 #endif  // FST_LIB_SYMBOL_TABLE_H__
    508