1 2 // Licensed under the Apache License, Version 2.0 (the "License"); 3 // you may not use this file except in compliance with the License. 4 // You may obtain a copy of the License at 5 // 6 // http://www.apache.org/licenses/LICENSE-2.0 7 // 8 // Unless required by applicable law or agreed to in writing, software 9 // distributed under the License is distributed on an "AS IS" BASIS, 10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 // 14 // Copyright 2005-2010 Google, Inc. 15 // All Rights Reserved. 16 // 17 // Author : Johan Schalkwyk 18 // 19 // \file 20 // Classes to provide symbol-to-integer and integer-to-symbol mappings. 21 22 #ifndef FST_LIB_SYMBOL_TABLE_H__ 23 #define FST_LIB_SYMBOL_TABLE_H__ 24 25 #include <cstring> 26 #include <string> 27 #include <utility> 28 using std::pair; using std::make_pair; 29 #include <vector> 30 using std::vector; 31 32 33 #include <fst/compat.h> 34 #include <iostream> 35 #include <fstream> 36 37 38 #include <map> 39 40 DECLARE_bool(fst_compat_symbols); 41 42 namespace fst { 43 44 // WARNING: Reading via symbol table read options should 45 // not be used. This is a temporary work around for 46 // reading symbol ranges of previously stored symbol sets. 47 struct SymbolTableReadOptions { 48 SymbolTableReadOptions() { } 49 50 SymbolTableReadOptions(vector<pair<int64, int64> > string_hash_ranges_, 51 const string& source_) 52 : string_hash_ranges(string_hash_ranges_), 53 source(source_) { } 54 55 vector<pair<int64, int64> > string_hash_ranges; 56 string source; 57 }; 58 59 class SymbolTableImpl { 60 public: 61 SymbolTableImpl(const string &name) 62 : name_(name), 63 available_key_(0), 64 dense_key_limit_(0), 65 check_sum_finalized_(false) {} 66 67 explicit SymbolTableImpl(const SymbolTableImpl& impl) 68 : name_(impl.name_), 69 available_key_(0), 70 dense_key_limit_(0), 71 check_sum_finalized_(false) { 72 for (size_t i = 0; i < impl.symbols_.size(); ++i) { 73 AddSymbol(impl.symbols_[i], impl.Find(impl.symbols_[i])); 74 } 75 } 76 77 ~SymbolTableImpl() { 78 for (size_t i = 0; i < symbols_.size(); ++i) 79 delete[] symbols_[i]; 80 } 81 82 // TODO(johans): Add flag to specify whether the symbol 83 // should be indexed as string or int or both. 84 int64 AddSymbol(const string& symbol, int64 key); 85 86 int64 AddSymbol(const string& symbol) { 87 int64 key = Find(symbol); 88 return (key == -1) ? AddSymbol(symbol, available_key_++) : key; 89 } 90 91 static SymbolTableImpl* ReadText(istream &strm, 92 const string &name, 93 bool allow_negative = false); 94 95 static SymbolTableImpl* Read(istream &strm, 96 const SymbolTableReadOptions& opts); 97 98 bool Write(ostream &strm) const; 99 100 // 101 // Return the string associated with the key. If the key is out of 102 // range (<0, >max), return an empty string. 103 string Find(int64 key) const { 104 if (key >=0 && key < dense_key_limit_) 105 return string(symbols_[key]); 106 107 map<int64, const char*>::const_iterator it = 108 key_map_.find(key); 109 if (it == key_map_.end()) { 110 return ""; 111 } 112 return string(it->second); 113 } 114 115 // 116 // Return the key associated with the symbol. If the symbol 117 // does not exists, return SymbolTable::kNoSymbol. 118 int64 Find(const string& symbol) const { 119 return Find(symbol.c_str()); 120 } 121 122 // 123 // Return the key associated with the symbol. If the symbol 124 // does not exists, return SymbolTable::kNoSymbol. 125 int64 Find(const char* symbol) const { 126 map<const char *, int64, StrCmp>::const_iterator it = 127 symbol_map_.find(symbol); 128 if (it == symbol_map_.end()) { 129 return -1; 130 } 131 return it->second; 132 } 133 134 int64 GetNthKey(ssize_t pos) const { 135 if ((pos < 0) || (pos >= symbols_.size())) return -1; 136 else return Find(symbols_[pos]); 137 } 138 139 const string& Name() const { return name_; } 140 141 int IncrRefCount() const { 142 return ref_count_.Incr(); 143 } 144 int DecrRefCount() const { 145 return ref_count_.Decr(); 146 } 147 int RefCount() const { 148 return ref_count_.count(); 149 } 150 151 string CheckSum() const { 152 MutexLock check_sum_lock(&check_sum_mutex_); 153 MaybeRecomputeCheckSum(); 154 return check_sum_string_; 155 } 156 157 string LabeledCheckSum() const { 158 MutexLock check_sum_lock(&check_sum_mutex_); 159 MaybeRecomputeCheckSum(); 160 return labeled_check_sum_string_; 161 } 162 163 int64 AvailableKey() const { 164 return available_key_; 165 } 166 167 size_t NumSymbols() const { 168 return symbols_.size(); 169 } 170 171 private: 172 // Recomputes the checksums (both of them) if we've had changes since the last 173 // computation (i.e., if check_sum_finalized_ is false). 174 void MaybeRecomputeCheckSum() const; 175 176 struct StrCmp { 177 bool operator()(const char *s1, const char *s2) const { 178 return strcmp(s1, s2) < 0; 179 } 180 }; 181 182 string name_; 183 int64 available_key_; 184 int64 dense_key_limit_; 185 vector<const char *> symbols_; 186 map<int64, const char*> key_map_; 187 map<const char *, int64, StrCmp> symbol_map_; 188 189 mutable RefCounter ref_count_; 190 mutable bool check_sum_finalized_; 191 mutable CheckSummer check_sum_; 192 mutable CheckSummer labeled_check_sum_; 193 mutable string check_sum_string_; 194 mutable string labeled_check_sum_string_; 195 mutable Mutex check_sum_mutex_; 196 }; 197 198 // 199 // \class SymbolTable 200 // \brief Symbol (string) to int and reverse mapping 201 // 202 // The SymbolTable implements the mappings of labels to strings and reverse. 203 // SymbolTables are used to describe the alphabet of the input and output 204 // labels for arcs in a Finite State Transducer. 205 // 206 // SymbolTables are reference counted and can therefore be shared across 207 // multiple machines. For example a language model grammar G, with a 208 // SymbolTable for the words in the language model can share this symbol 209 // table with the lexical representation L o G. 210 // 211 class SymbolTable { 212 public: 213 static const int64 kNoSymbol = -1; 214 215 // Construct symbol table with a unique name. 216 SymbolTable(const string& name) : impl_(new SymbolTableImpl(name)) {} 217 218 // Create a reference counted copy. 219 SymbolTable(const SymbolTable& table) : impl_(table.impl_) { 220 impl_->IncrRefCount(); 221 } 222 223 // Derefence implentation object. When reference count hits 0, delete 224 // implementation. 225 virtual ~SymbolTable() { 226 if (!impl_->DecrRefCount()) delete impl_; 227 } 228 229 // Read an ascii representation of the symbol table from an istream. Pass a 230 // name to give the resulting SymbolTable. 231 static SymbolTable* ReadText(istream &strm, 232 const string& name, 233 bool allow_negative = false) { 234 SymbolTableImpl* impl = SymbolTableImpl::ReadText(strm, 235 name, 236 allow_negative); 237 if (!impl) 238 return 0; 239 else 240 return new SymbolTable(impl); 241 } 242 243 // read an ascii representation of the symbol table 244 static SymbolTable* ReadText(const string& filename, 245 bool allow_negative = false) { 246 ifstream strm(filename.c_str(), ifstream::in); 247 if (!strm) { 248 LOG(ERROR) << "SymbolTable::ReadText: Can't open file " << filename; 249 return 0; 250 } 251 return ReadText(strm, filename, allow_negative); 252 } 253 254 255 // WARNING: Reading via symbol table read options should 256 // not be used. This is a temporary work around. 257 static SymbolTable* Read(istream &strm, 258 const SymbolTableReadOptions& opts) { 259 SymbolTableImpl* impl = SymbolTableImpl::Read(strm, opts); 260 if (!impl) 261 return 0; 262 else 263 return new SymbolTable(impl); 264 } 265 266 // read a binary dump of the symbol table from a stream 267 static SymbolTable* Read(istream &strm, const string& source) { 268 SymbolTableReadOptions opts; 269 opts.source = source; 270 return Read(strm, opts); 271 } 272 273 // read a binary dump of the symbol table 274 static SymbolTable* Read(const string& filename) { 275 ifstream strm(filename.c_str(), ifstream::in | ifstream::binary); 276 if (!strm) { 277 LOG(ERROR) << "SymbolTable::Read: Can't open file " << filename; 278 return 0; 279 } 280 return Read(strm, filename); 281 } 282 283 //-------------------------------------------------------- 284 // Derivable Interface (final) 285 //-------------------------------------------------------- 286 // create a reference counted copy 287 virtual SymbolTable* Copy() const { 288 return new SymbolTable(*this); 289 } 290 291 // Add a symbol with given key to table. A symbol table also 292 // keeps track of the last available key (highest key value in 293 // the symbol table). 294 virtual int64 AddSymbol(const string& symbol, int64 key) { 295 MutateCheck(); 296 return impl_->AddSymbol(symbol, key); 297 } 298 299 // Add a symbol to the table. The associated value key is automatically 300 // assigned by the symbol table. 301 virtual int64 AddSymbol(const string& symbol) { 302 MutateCheck(); 303 return impl_->AddSymbol(symbol); 304 } 305 306 // Add another symbol table to this table. All key values will be offset 307 // by the current available key (highest key value in the symbol table). 308 // Note string symbols with the same key value with still have the same 309 // key value after the symbol table has been merged, but a different 310 // value. Adding symbol tables do not result in changes in the base table. 311 virtual void AddTable(const SymbolTable& table); 312 313 // return the name of the symbol table 314 virtual const string& Name() const { 315 return impl_->Name(); 316 } 317 318 // Return the label-agnostic MD5 check-sum for this table. All new symbols 319 // added to the table will result in an updated checksum. 320 // DEPRECATED. 321 virtual string CheckSum() const { 322 return impl_->CheckSum(); 323 } 324 325 // Same as CheckSum(), but this returns an label-dependent version. 326 virtual string LabeledCheckSum() const { 327 return impl_->LabeledCheckSum(); 328 } 329 330 virtual bool Write(ostream &strm) const { 331 return impl_->Write(strm); 332 } 333 334 bool Write(const string& filename) const { 335 ofstream strm(filename.c_str(), ofstream::out | ofstream::binary); 336 if (!strm) { 337 LOG(ERROR) << "SymbolTable::Write: Can't open file " << filename; 338 return false; 339 } 340 return Write(strm); 341 } 342 343 // Dump an ascii text representation of the symbol table via a stream 344 virtual bool WriteText(ostream &strm) const; 345 346 // Dump an ascii text representation of the symbol table 347 bool WriteText(const string& filename) const { 348 ofstream strm(filename.c_str()); 349 if (!strm) { 350 LOG(ERROR) << "SymbolTable::WriteText: Can't open file " << filename; 351 return false; 352 } 353 return WriteText(strm); 354 } 355 356 // Return the string associated with the key. If the key is out of 357 // range (<0, >max), log error and return an empty string. 358 virtual string Find(int64 key) const { 359 return impl_->Find(key); 360 } 361 362 // Return the key associated with the symbol. If the symbol 363 // does not exists, log error and return SymbolTable::kNoSymbol 364 virtual int64 Find(const string& symbol) const { 365 return impl_->Find(symbol); 366 } 367 368 // Return the key associated with the symbol. If the symbol 369 // does not exists, log error and return SymbolTable::kNoSymbol 370 virtual int64 Find(const char* symbol) const { 371 return impl_->Find(symbol); 372 } 373 374 // Return the current available key (i.e highest key number+1) in 375 // the symbol table 376 virtual int64 AvailableKey(void) const { 377 return impl_->AvailableKey(); 378 } 379 380 // Return the current number of symbols in table (not necessarily 381 // equal to AvailableKey()) 382 virtual size_t NumSymbols(void) const { 383 return impl_->NumSymbols(); 384 } 385 386 virtual int64 GetNthKey(ssize_t pos) const { 387 return impl_->GetNthKey(pos); 388 } 389 390 private: 391 explicit SymbolTable(SymbolTableImpl* impl) : impl_(impl) {} 392 393 void MutateCheck() { 394 // Copy on write 395 if (impl_->RefCount() > 1) { 396 impl_->DecrRefCount(); 397 impl_ = new SymbolTableImpl(*impl_); 398 } 399 } 400 401 const SymbolTableImpl* Impl() const { 402 return impl_; 403 } 404 405 private: 406 SymbolTableImpl* impl_; 407 408 void operator=(const SymbolTable &table); // disallow 409 }; 410 411 412 // 413 // \class SymbolTableIterator 414 // \brief Iterator class for symbols in a symbol table 415 class SymbolTableIterator { 416 public: 417 SymbolTableIterator(const SymbolTable& table) 418 : table_(table), 419 pos_(0), 420 nsymbols_(table.NumSymbols()), 421 key_(table.GetNthKey(0)) { } 422 423 ~SymbolTableIterator() { } 424 425 // is iterator done 426 bool Done(void) { 427 return (pos_ == nsymbols_); 428 } 429 430 // return the Value() of the current symbol (int64 key) 431 int64 Value(void) { 432 return key_; 433 } 434 435 // return the string of the current symbol 436 string Symbol(void) { 437 return table_.Find(key_); 438 } 439 440 // advance iterator forward 441 void Next(void) { 442 ++pos_; 443 if (pos_ < nsymbols_) key_ = table_.GetNthKey(pos_); 444 } 445 446 // reset iterator 447 void Reset(void) { 448 pos_ = 0; 449 key_ = table_.GetNthKey(0); 450 } 451 452 private: 453 const SymbolTable& table_; 454 ssize_t pos_; 455 size_t nsymbols_; 456 int64 key_; 457 }; 458 459 460 // Tests compatibilty between two sets of symbol tables 461 inline bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, 462 bool warning = true) { 463 if (!FLAGS_fst_compat_symbols) { 464 return true; 465 } else if (!syms1 && !syms2) { 466 return true; 467 } else if (syms1 && !syms2) { 468 if (warning) 469 LOG(WARNING) << 470 "CompatSymbols: first symbol table present but second missing"; 471 return false; 472 } else if (!syms1 && syms2) { 473 if (warning) 474 LOG(WARNING) << 475 "CompatSymbols: second symbol table present but first missing"; 476 return false; 477 } else if (syms1->LabeledCheckSum() != syms2->LabeledCheckSum()) { 478 if (warning) 479 LOG(WARNING) << "CompatSymbols: Symbol table check sums do not match"; 480 return false; 481 } else { 482 return true; 483 } 484 } 485 486 487 // Relabels a symbol table as specified by the input vector of pairs 488 // (old label, new label). The new symbol table only retains symbols 489 // for which a relabeling is *explicitely* specified. 490 // TODO(allauzen): consider adding options to allow for some form 491 // of implicit identity relabeling. 492 template <class Label> 493 SymbolTable *RelabelSymbolTable(const SymbolTable *table, 494 const vector<pair<Label, Label> > &pairs) { 495 SymbolTable *new_table = new SymbolTable( 496 table->Name().empty() ? string() : 497 (string("relabeled_") + table->Name())); 498 499 for (size_t i = 0; i < pairs.size(); ++i) 500 new_table->AddSymbol(table->Find(pairs[i].first), pairs[i].second); 501 502 return new_table; 503 } 504 505 } // namespace fst 506 507 #endif // FST_LIB_SYMBOL_TABLE_H__ 508