Home | History | Annotate | Download | only in lib
      1 // symbol-table.h
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 //
     16 // \file
     17 // Classes to provide symbol-to-integer and integer-to-symbol mappings.
     18 
     19 #ifndef FST_LIB_SYMBOL_TABLE_H__
     20 #define FST_LIB_SYMBOL_TABLE_H__
     21 
     22 #include <fstream>
     23 #include <iostream>
     24 #include <string>
     25 #include <unordered_map>
     26 #include <vector>
     27 
     28 #include "fst/lib/compat.h"
     29 
     30 
     31 
     32 DECLARE_bool(fst_compat_symbols);
     33 
     34 namespace fst {
     35 
     36 class SymbolTableImpl {
     37   friend class SymbolTableIterator;
     38  public:
     39   SymbolTableImpl(const string &name)
     40       : name_(name), available_key_(0), ref_count_(1),
     41         check_sum_finalized_(false) {}
     42   ~SymbolTableImpl() {
     43     for (size_t i = 0; i < symbols_.size(); ++i)
     44       delete[] symbols_[i];
     45   }
     46 
     47   int64 AddSymbol(const string& symbol, int64 key);
     48 
     49   int64 AddSymbol(const string& symbol) {
     50     int64 key = Find(symbol);
     51     return (key == -1) ? AddSymbol(symbol, available_key_++) : key;
     52   }
     53 
     54   void AddTable(SymbolTableImpl* table) {
     55     for (size_t i = 0; i < table->symbols_.size(); ++i) {
     56       AddSymbol(table->symbols_[i]);
     57     }
     58   }
     59 
     60   static SymbolTableImpl* ReadText(const string& filename);
     61 
     62   static SymbolTableImpl* Read(istream &strm, const string& source);
     63 
     64   bool Write(ostream &strm) const;
     65 
     66   bool WriteText(ostream &strm) const;
     67 
     68   //
     69   // Return the string associated with the key. If the key is out of
     70   // range (<0, >max), return an empty string.
     71   string Find(int64 key) const {
     72     std::unordered_map<int64, string>::const_iterator it =
     73       key_map_.find(key);
     74     if (it == key_map_.end()) {
     75       return "";
     76     }
     77     return it->second;
     78   }
     79 
     80   //
     81   // Return the key associated with the symbol. If the symbol
     82   // does not exists, return -1.
     83   int64 Find(const string& symbol) const {
     84     return Find(symbol.c_str());
     85   }
     86 
     87   //
     88   // Return the key associated with the symbol. If the symbol
     89   // does not exists, return -1.
     90   int64 Find(const char* symbol) const {
     91     unordered_map<string, int64>::const_iterator it =
     92       symbol_map_.find(symbol);
     93     if (it == symbol_map_.end()) {
     94       return -1;
     95     }
     96     return it->second;
     97   }
     98 
     99   const string& Name() const { return name_; }
    100 
    101   int IncrRefCount() const {
    102     return ++ref_count_;
    103   }
    104   int DecrRefCount() const {
    105     return --ref_count_;
    106   }
    107 
    108   string CheckSum() const {
    109     if (!check_sum_finalized_) {
    110       RecomputeCheckSum();
    111       check_sum_string_ = check_sum_.Digest();
    112     }
    113     return check_sum_string_;
    114   }
    115 
    116   int64 AvailableKey() const {
    117     return available_key_;
    118   }
    119 
    120   // private support methods
    121  private:
    122   void RecomputeCheckSum() const;
    123   static SymbolTableImpl* Read1(istream &, const string &);
    124 
    125   string name_;
    126   int64 available_key_;
    127   vector<const char *> symbols_;
    128   std::unordered_map<int64, string> key_map_;
    129   std::unordered_map<string, int64> symbol_map_;
    130 
    131   mutable int ref_count_;
    132   mutable bool check_sum_finalized_;
    133   mutable MD5 check_sum_;
    134   mutable string check_sum_string_;
    135 
    136   DISALLOW_EVIL_CONSTRUCTORS(SymbolTableImpl);
    137 };
    138 
    139 
    140 class SymbolTableIterator;
    141 
    142 //
    143 // \class SymbolTable
    144 // \brief Symbol (string) to int and reverse mapping
    145 //
    146 // The SymbolTable implements the mappings of labels to strings and reverse.
    147 // SymbolTables are used to describe the alphabet of the input and output
    148 // labels for arcs in a Finite State Transducer.
    149 //
    150 // SymbolTables are reference counted and can therefore be shared across
    151 // multiple machines. For example a language model grammar G, with a
    152 // SymbolTable for the words in the language model can share this symbol
    153 // table with the lexical representation L o G.
    154 //
    155 class SymbolTable {
    156   friend class SymbolTableIterator;
    157  public:
    158   static const int64 kNoSymbol = -1;
    159 
    160   // Construct symbol table with a unique name.
    161   SymbolTable(const string& name) : impl_(new SymbolTableImpl(name)) {}
    162 
    163   // Create a reference counted copy.
    164   SymbolTable(const SymbolTable& table) : impl_(table.impl_) {
    165     impl_->IncrRefCount();
    166   }
    167 
    168   // Derefence implentation object. When reference count hits 0, delete
    169   // implementation.
    170   ~SymbolTable() {
    171     if (!impl_->DecrRefCount()) delete impl_;
    172   }
    173 
    174   // create a reference counted copy
    175   SymbolTable* Copy() const {
    176     return new SymbolTable(*this);
    177   }
    178 
    179   // Add a symbol with given key to table. A symbol table also
    180   // keeps track of the last available key (highest key value in
    181   // the symbol table).
    182   //
    183   // \param symbol string symbol to add
    184   // \param key associated key for string symbol
    185   // \return the key created by the symbol table. Symbols allready added to
    186   //         the symbol table will not get a different key.
    187   int64 AddSymbol(const string& symbol, int64 key) {
    188     return impl_->AddSymbol(symbol, key);
    189   }
    190 
    191   // Add a symbol to the table. The associated value key is automatically
    192   // assigned by the symbol table.
    193   //
    194   // \param symbol string to add to the table
    195   // \return the value key assigned to the associated string symbol
    196   int64 AddSymbol(const string& symbol) {
    197     return impl_->AddSymbol(symbol);
    198   }
    199 
    200   // Add another symbol table to this table. All key values will be offset
    201   // by the current available key (highest key value in the symbol table).
    202   // Note string symbols with the same key value with still have the same
    203   // key value after the symbol table has been merged, but a different
    204   // value. Adding symbol tables do not result in changes in the base table.
    205   //
    206   // Merging N symbol tables is often useful when combining the various
    207   // name spaces of transducers to a unified representation.
    208   //
    209   // \param table the symbol table to add to this table
    210   void AddTable(const SymbolTable& table) {
    211     return impl_->AddTable(table.impl_);
    212   }
    213 
    214   // return the name of the symbol table
    215   const string& Name() const {
    216     return impl_->Name();
    217   }
    218 
    219   // return the MD5 check-sum for this table. All new symbols added to
    220   // the table will result in an updated checksum.
    221   string CheckSum() const {
    222     return impl_->CheckSum();
    223   }
    224 
    225   // read an ascii representation of the symbol table
    226   static SymbolTable* ReadText(const string& filename) {
    227     SymbolTableImpl* impl = SymbolTableImpl::ReadText(filename);
    228     if (!impl)
    229       return 0;
    230     else
    231       return new SymbolTable(impl);
    232   }
    233 
    234   // read a binary dump of the symbol table
    235   static SymbolTable* Read(istream &strm, const string& source) {
    236     SymbolTableImpl* impl = SymbolTableImpl::Read(strm, source);
    237     if (!impl)
    238       return 0;
    239     else
    240       return new SymbolTable(impl);
    241   }
    242 
    243   // read a binary dump of the symbol table
    244   static SymbolTable* Read(const string& filename) {
    245     ifstream strm(filename.c_str());
    246     if (!strm) {
    247       LOG(ERROR) << "SymbolTable::Read: Can't open file " << filename;
    248       return 0;
    249     }
    250     return Read(strm, filename);
    251   }
    252 
    253   bool Write(ostream  &strm) const {
    254     return impl_->Write(strm);
    255   }
    256 
    257   bool Write(const string& filename) const {
    258     ofstream strm(filename.c_str());
    259     if (!strm) {
    260       LOG(ERROR) << "SymbolTable::Write: Can't open file " << filename;
    261       return false;
    262     }
    263     return Write(strm);
    264   }
    265 
    266   // Dump an ascii text representation of the symbol table
    267   bool WriteText(ostream &strm) const {
    268     return impl_->WriteText(strm);
    269   }
    270 
    271   // Dump an ascii text representation of the symbol table
    272   bool WriteText(const string& filename) const {
    273     ofstream strm(filename.c_str());
    274     if (!strm) {
    275       LOG(ERROR) << "SymbolTable::WriteText: Can't open file " << filename;
    276       return false;
    277     }
    278     return WriteText(strm);
    279   }
    280 
    281   // Return the string associated with the key. If the key is out of
    282   // range (<0, >max), log error and return an empty string.
    283   string Find(int64 key) const {
    284     return impl_->Find(key);
    285   }
    286 
    287   // Return the key associated with the symbol. If the symbol
    288   // does not exists, log error and  return -1
    289   int64 Find(const string& symbol) const {
    290     return impl_->Find(symbol);
    291   }
    292 
    293   // Return the key associated with the symbol. If the symbol
    294   // does not exists, log error and  return -1
    295   int64 Find(const char* symbol) const {
    296     return impl_->Find(symbol);
    297   }
    298 
    299   // return the current available key (i.e highest key number) in
    300   // the symbol table
    301   int64 AvailableKey(void) const {
    302     return impl_->AvailableKey();
    303   }
    304 
    305  protected:
    306   explicit SymbolTable(SymbolTableImpl* impl) : impl_(impl) {}
    307 
    308   const SymbolTableImpl* Impl() const {
    309     return impl_;
    310   }
    311 
    312  private:
    313   SymbolTableImpl* impl_;
    314 
    315 
    316   void operator=(const SymbolTable &table);  // disallow
    317 };
    318 
    319 
    320 //
    321 // \class SymbolTableIterator
    322 // \brief Iterator class for symbols in a symbol table
    323 class SymbolTableIterator {
    324  public:
    325   // Constructor creates a refcounted copy of underlying implementation
    326   SymbolTableIterator(const SymbolTable& symbol_table) {
    327     impl_ = symbol_table.Impl();
    328     impl_->IncrRefCount();
    329     pos_ = 0;
    330     size_ = impl_->symbols_.size();
    331   }
    332 
    333   // decrement implementation refcount, and delete if 0
    334   ~SymbolTableIterator() {
    335     if (!impl_->DecrRefCount()) delete impl_;
    336   }
    337 
    338   // is iterator done
    339   bool Done(void) {
    340     return (pos_ == size_);
    341   }
    342 
    343   // return the Value() of the current symbol (in64 key)
    344   int64 Value(void) {
    345     return impl_->Find(impl_->symbols_[pos_]);
    346   }
    347 
    348   // return the string of the current symbol
    349   const char* Symbol(void) {
    350     return impl_->symbols_[pos_];
    351   }
    352 
    353   // advance iterator forward
    354   void Next(void) {
    355     if (Done()) return;
    356     ++pos_;
    357   }
    358 
    359   // reset iterator
    360   void Reset(void) {
    361     pos_ = 0;
    362   }
    363 
    364  private:
    365   const SymbolTableImpl* impl_;
    366   size_t pos_;
    367   size_t size_;
    368 };
    369 
    370 
    371 // Tests compatibilty between two sets of symbol tables
    372 inline bool CompatSymbols(const SymbolTable *syms1,
    373                           const SymbolTable *syms2) {
    374   if (!FLAGS_fst_compat_symbols)
    375     return true;
    376   else if (!syms1 && !syms2)
    377     return true;
    378   else if ((syms1 && !syms2) || (!syms1 && syms2))
    379     return false;
    380   else
    381     return syms1->CheckSum() == syms2->CheckSum();
    382 }
    383 
    384 }  // namespace fst
    385 
    386 #endif  // FST_LIB_SYMBOL_TABLE_H__
    387