Home | History | Annotate | Download | only in lib
      1 
      2 // Licensed under the Apache License, Version 2.0 (the "License");
      3 // you may not use this file except in compliance with the License.
      4 // You may obtain a copy of the License at
      5 //
      6 //     http://www.apache.org/licenses/LICENSE-2.0
      7 //
      8 // Unless required by applicable law or agreed to in writing, software
      9 // distributed under the License is distributed on an "AS IS" BASIS,
     10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     11 // See the License for the specific language governing permissions and
     12 // limitations under the License.
     13 //
     14 // Copyright 2005-2010 Google, Inc.
     15 // All Rights Reserved.
     16 //
     17 // Author : Johan Schalkwyk
     18 //
     19 // \file
     20 // Classes to provide symbol-to-integer and integer-to-symbol mappings.
     21 
     22 #include <fst/symbol-table.h>
     23 
     24 #include <fst/util.h>
     25 
// Command-line flags defined by this translation unit.
//
// fst_compat_symbols (default: true): see the help string below; the flag is
// not read anywhere in this file.
DEFINE_bool(fst_compat_symbols, true,
            "Require symbol tables to match when appropriate");
// fst_field_separator (default: tab and space): supplies the default
// separator set of SymbolTableTextOptions.  SymbolTableImpl::ReadText()
// splits on any character of the set, while SymbolTable::WriteText() emits
// only its first character.
DEFINE_string(fst_field_separator, "\t ",
              "Set of characters used as a separator between printed fields");
     30 
     31 namespace fst {
     32 
// Maximum line length in textual symbols file.  This is the size of the
// character buffer handed to getline() in SymbolTableImpl::ReadText().
const int kLineLen = 8096;

// Identifies stream data as a symbol table (and its endianness).
// SymbolTableImpl::Write() emits it as the first field of the binary format.
static const int32 kSymbolTableMagicNumber = 2125658996;
     38 
     39 SymbolTableTextOptions::SymbolTableTextOptions()
     40     : allow_negative(false), fst_field_separator(FLAGS_fst_field_separator) { }
     41 
     42 SymbolTableImpl* SymbolTableImpl::ReadText(istream &strm,
     43                                            const string &filename,
     44                                            const SymbolTableTextOptions &opts) {
     45   SymbolTableImpl* impl = new SymbolTableImpl(filename);
     46 
     47   int64 nline = 0;
     48   char line[kLineLen];
     49   while (strm.getline(line, kLineLen)) {
     50     ++nline;
     51     vector<char *> col;
     52     string separator = opts.fst_field_separator + "\n";
     53     SplitToVector(line, separator.c_str(), &col, true);
     54     if (col.size() == 0)  // empty line
     55       continue;
     56     if (col.size() != 2) {
     57       LOG(ERROR) << "SymbolTable::ReadText: Bad number of columns ("
     58                  << col.size() << "), "
     59                  << "file = " << filename << ", line = " << nline
     60                  << ":<" << line << ">";
     61       delete impl;
     62       return 0;
     63     }
     64     const char *symbol = col[0];
     65     const char *value = col[1];
     66     char *p;
     67     int64 key = strtoll(value, &p, 10);
     68     if (p < value + strlen(value) ||
     69         (!opts.allow_negative && key < 0) || key == -1) {
     70       LOG(ERROR) << "SymbolTable::ReadText: Bad non-negative integer \""
     71                  << value << "\", "
     72                  << "file = " << filename << ", line = " << nline;
     73       delete impl;
     74       return 0;
     75     }
     76     impl->AddSymbol(symbol, key);
     77   }
     78 
     79   return impl;
     80 }
     81 
     82 void SymbolTableImpl::MaybeRecomputeCheckSum() const {
     83   {
     84     ReaderMutexLock check_sum_lock(&check_sum_mutex_);
     85     if (check_sum_finalized_)
     86       return;
     87   }
     88 
     89   // We'll aquire an exclusive lock to recompute the checksums.
     90   MutexLock check_sum_lock(&check_sum_mutex_);
     91   if (check_sum_finalized_)  // Another thread (coming in around the same time
     92     return;                  // might have done it already).  So we recheck.
     93 
     94   // Calculate the original label-agnostic check sum.
     95   CheckSummer check_sum;
     96   for (int64 i = 0; i < symbols_.size(); ++i)
     97     check_sum.Update(symbols_[i], strlen(symbols_[i]) + 1);
     98   check_sum_string_ = check_sum.Digest();
     99 
    100   // Calculate the safer, label-dependent check sum.
    101   CheckSummer labeled_check_sum;
    102   for (int64 key = 0; key < dense_key_limit_; ++key) {
    103     ostringstream line;
    104     line << symbols_[key] << '\t' << key;
    105     labeled_check_sum.Update(line.str().data(), line.str().size());
    106   }
    107   for (map<int64, const char*>::const_iterator it =
    108        key_map_.begin();
    109        it != key_map_.end();
    110        ++it) {
    111     if (it->first >= dense_key_limit_) {
    112       ostringstream line;
    113       line << it->second << '\t' << it->first;
    114       labeled_check_sum.Update(line.str().data(), line.str().size());
    115     }
    116   }
    117   labeled_check_sum_string_ = labeled_check_sum.Digest();
    118 
    119   check_sum_finalized_ = true;
    120 }
    121 
    122 int64 SymbolTableImpl::AddSymbol(const string& symbol, int64 key) {
    123   map<const char *, int64, StrCmp>::const_iterator it =
    124       symbol_map_.find(symbol.c_str());
    125   if (it == symbol_map_.end()) {  // only add if not in table
    126     check_sum_finalized_ = false;
    127 
    128     char *csymbol = new char[symbol.size() + 1];
    129     strcpy(csymbol, symbol.c_str());
    130     symbols_.push_back(csymbol);
    131     key_map_[key] = csymbol;
    132     symbol_map_[csymbol] = key;
    133 
    134     if (key >= available_key_) {
    135       available_key_ = key + 1;
    136     }
    137   } else {
    138     // Log if symbol already in table with different key
    139     if (it->second != key) {
    140       VLOG(1) << "SymbolTable::AddSymbol: symbol = " << symbol
    141               << " already in symbol_map_ with key = "
    142               << it->second
    143               << " but supplied new key = " << key
    144               << " (ignoring new key)";
    145     }
    146   }
    147   return key;
    148 }
    149 
    150 static bool IsInRange(const vector<pair<int64, int64> >& ranges,
    151                       int64 key) {
    152   if (ranges.size() == 0) return true;
    153   for (size_t i = 0; i < ranges.size(); ++i) {
    154     if (key >= ranges[i].first && key <= ranges[i].second)
    155       return true;
    156   }
    157   return false;
    158 }
    159 
    160 SymbolTableImpl* SymbolTableImpl::Read(istream &strm,
    161                                        const SymbolTableReadOptions& opts) {
    162   int32 magic_number = 0;
    163   ReadType(strm, &magic_number);
    164   if (!strm) {
    165     LOG(ERROR) << "SymbolTable::Read: read failed";
    166     return 0;
    167   }
    168   string name;
    169   ReadType(strm, &name);
    170   SymbolTableImpl* impl = new SymbolTableImpl(name);
    171   ReadType(strm, &impl->available_key_);
    172   int64 size;
    173   ReadType(strm, &size);
    174   if (!strm) {
    175     LOG(ERROR) << "SymbolTable::Read: read failed";
    176     delete impl;
    177     return 0;
    178   }
    179 
    180   string symbol;
    181   int64 key;
    182   impl->check_sum_finalized_ = false;
    183   for (size_t i = 0; i < size; ++i) {
    184     ReadType(strm, &symbol);
    185     ReadType(strm, &key);
    186     if (!strm) {
    187       LOG(ERROR) << "SymbolTable::Read: read failed";
    188       delete impl;
    189       return 0;
    190     }
    191 
    192     char *csymbol = new char[symbol.size() + 1];
    193     strcpy(csymbol, symbol.c_str());
    194     impl->symbols_.push_back(csymbol);
    195     if (key == impl->dense_key_limit_ &&
    196         key == impl->symbols_.size() - 1)
    197       impl->dense_key_limit_ = impl->symbols_.size();
    198     else
    199       impl->key_map_[key] = csymbol;
    200 
    201     if (IsInRange(opts.string_hash_ranges, key)) {
    202       impl->symbol_map_[csymbol] = key;
    203     }
    204   }
    205   return impl;
    206 }
    207 
    208 bool SymbolTableImpl::Write(ostream &strm) const {
    209   WriteType(strm, kSymbolTableMagicNumber);
    210   WriteType(strm, name_);
    211   WriteType(strm, available_key_);
    212   int64 size = symbols_.size();
    213   WriteType(strm, size);
    214   // first write out dense keys
    215   int64 i = 0;
    216   for (; i < dense_key_limit_; ++i) {
    217     WriteType(strm, string(symbols_[i]));
    218     WriteType(strm, i);
    219   }
    220   // next write out the remaining non densely packed keys
    221   for (map<const char *, int64, StrCmp>::const_iterator it =
    222            symbol_map_.begin(); it != symbol_map_.end(); ++it) {
    223     if ((it->second >= 0) && (it->second < dense_key_limit_))
    224       continue;
    225     WriteType(strm, string(it->first));
    226     WriteType(strm, it->second);
    227     ++i;
    228   }
    229   if (i != size) {
    230     LOG(ERROR) << "SymbolTable::Write:  write failed";
    231     return false;
    232   }
    233   strm.flush();
    234   if (!strm) {
    235     LOG(ERROR) << "SymbolTable::Write: write failed";
    236     return false;
    237   }
    238   return true;
    239 }
    240 
// Out-of-class definition of the static constant.  No initializer is given
// here, so its value has to come from the declaration inside the class.
const int64 SymbolTable::kNoSymbol;
    242 
    243 
    244 void SymbolTable::AddTable(const SymbolTable& table) {
    245   for (SymbolTableIterator iter(table); !iter.Done(); iter.Next())
    246     impl_->AddSymbol(iter.Symbol());
    247 }
    248 
    249 bool SymbolTable::WriteText(ostream &strm,
    250                             const SymbolTableTextOptions &opts) const {
    251   if (opts.fst_field_separator.empty()) {
    252     LOG(ERROR) << "Missing required field separator";
    253     return false;
    254   }
    255   bool once_only = false;
    256   for (SymbolTableIterator iter(*this); !iter.Done(); iter.Next()) {
    257     ostringstream line;
    258     if (iter.Value() < 0 && !opts.allow_negative && !once_only) {
    259       LOG(WARNING) << "Negative symbol table entry when not allowed";
    260       once_only = true;
    261     }
    262     line << iter.Symbol() << opts.fst_field_separator[0] << iter.Value()
    263          << '\n';
    264     strm.write(line.str().data(), line.str().length());
    265   }
    266   return true;
    267 }
    268 }  // namespace fst
    269