Home | History | Annotate | Download | only in far
      1 // sttable.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: allauzen (at) google.com (Cyril Allauzen)
     17 //
     18 // \file
     19 // A generic string-to-type table file format
     20 //
     21 // This is not meant as a generalization of SSTable. This is more of
     22 // a simple replacement for SSTable in order to provide an open-source
     23 // implementation of the FAR format for the external version of the
     24 // FST Library.
     25 
     26 #ifndef FST_EXTENSIONS_FAR_STTABLE_H_
     27 #define FST_EXTENSIONS_FAR_STTABLE_H_
     28 
     29 #include <algorithm>
     30 #include <iostream>
     31 #include <fstream>
     32 #include <fst/util.h>
     33 
     34 namespace fst {
     35 
// Magic number stored at the very beginning of an STTable file. It is written
// by STTableWriter and checked by STTableReader and ReadSTTableHeader.
static const int32 kSTTableMagicNumber = 2125656924;
// File format version stored right after the magic number. The readers in
// this header reject files whose stored version differs from this value.
static const int32 kSTTableFileVersion = 1;
     38 
     39 // String-to-type table writing class for object of type 'T' using functor 'W'
     40 // to write an object of type 'T' from a stream. 'W' must conform to the
     41 // following interface:
     42 //
     43 //   struct Writer {
     44 //     void operator()(ostream &, const T &) const;
     45 //   };
     46 //
     47 template <class T, class W>
     48 class STTableWriter {
     49  public:
     50   typedef T EntryType;
     51   typedef W EntryWriter;
     52 
     53   explicit STTableWriter(const string &filename)
     54       : stream_(filename.c_str(), ofstream::out | ofstream::binary),
     55         error_(false) {
     56     WriteType(stream_, kSTTableMagicNumber);
     57     WriteType(stream_, kSTTableFileVersion);
     58     if (!stream_) {
     59       FSTERROR() << "STTableWriter::STTableWriter: error writing to file: "
     60                  << filename;
     61       error_=true;
     62     }
     63   }
     64 
     65   static STTableWriter<T, W> *Create(const string &filename) {
     66     if (filename.empty()) {
     67       LOG(ERROR) << "STTableWriter: writing to standard out unsupported.";
     68       return 0;
     69     }
     70     return new STTableWriter<T, W>(filename);
     71   }
     72 
     73   void Add(const string &key, const T &t) {
     74     if (key == "") {
     75       FSTERROR() << "STTableWriter::Add: key empty: " << key;
     76       error_ = true;
     77     } else if (key < last_key_) {
     78       FSTERROR() << "STTableWriter::Add: key disorder: " << key;
     79       error_ = true;
     80     }
     81     if (error_) return;
     82     last_key_ = key;
     83     positions_.push_back(stream_.tellp());
     84     WriteType(stream_, key);
     85     entry_writer_(stream_, t);
     86   }
     87 
     88   bool Error() const { return error_; }
     89 
     90   ~STTableWriter() {
     91     WriteType(stream_, positions_);
     92     WriteType(stream_, static_cast<int64>(positions_.size()));
     93   }
     94 
     95  private:
     96   EntryWriter entry_writer_;  // Write functor for 'EntryType'
     97   ofstream stream_;           // Output stream
     98   vector<int64> positions_;   // Position in file of each key-entry pair
     99   string last_key_;           // Last key
    100   bool error_;
    101 
    102   DISALLOW_COPY_AND_ASSIGN(STTableWriter);
    103 };
    104 
    105 
    106 // String-to-type table reading class for object of type 'T' using functor 'R'
    107 // to read an object of type 'T' form a stream. 'R' must conform to the
    108 // following interface:
    109 //
    110 //   struct Reader {
    111 //     T *operator()(istream &) const;
    112 //   };
    113 //
    114 template <class T, class R>
    115 class STTableReader {
    116  public:
    117   typedef T EntryType;
    118   typedef R EntryReader;
    119 
    120   explicit STTableReader(const vector<string> &filenames)
    121       : sources_(filenames), entry_(0), error_(false) {
    122     compare_ = new Compare(&keys_);
    123     keys_.resize(filenames.size());
    124     streams_.resize(filenames.size(), 0);
    125     positions_.resize(filenames.size());
    126     for (size_t i = 0; i < filenames.size(); ++i) {
    127       streams_[i] = new ifstream(
    128           filenames[i].c_str(), ifstream::in | ifstream::binary);
    129       int32 magic_number = 0, file_version = 0;
    130       ReadType(*streams_[i], &magic_number);
    131       ReadType(*streams_[i], &file_version);
    132       if (magic_number != kSTTableMagicNumber) {
    133         FSTERROR() << "STTableReader::STTableReader: wrong file type: "
    134                    << filenames[i];
    135         error_ = true;
    136         return;
    137       }
    138       if (file_version != kSTTableFileVersion) {
    139         FSTERROR() << "STTableReader::STTableReader: wrong file version: "
    140                    << filenames[i];
    141         error_ = true;
    142         return;
    143       }
    144       int64 num_entries;
    145       streams_[i]->seekg(-static_cast<int>(sizeof(int64)), ios_base::end);
    146       ReadType(*streams_[i], &num_entries);
    147       streams_[i]->seekg(-static_cast<int>(sizeof(int64)) *
    148                          (num_entries + 1), ios_base::end);
    149       positions_[i].resize(num_entries);
    150       for (size_t j = 0; (j < num_entries) && (*streams_[i]); ++j)
    151         ReadType(*streams_[i], &(positions_[i][j]));
    152       streams_[i]->seekg(positions_[i][0]);
    153       if (!*streams_[i]) {
    154         FSTERROR() << "STTableReader::STTableReader: error reading file: "
    155                    << filenames[i];
    156         error_ = true;
    157         return;
    158       }
    159 
    160     }
    161     MakeHeap();
    162   }
    163 
    164   ~STTableReader() {
    165     for (size_t i = 0; i < streams_.size(); ++i)
    166       delete streams_[i];
    167     delete compare_;
    168     if (entry_)
    169       delete entry_;
    170   }
    171 
    172   static STTableReader<T, R> *Open(const string &filename) {
    173     if (filename.empty()) {
    174       LOG(ERROR) << "STTableReader: reading from standard in not supported";
    175       return 0;
    176     }
    177     vector<string> filenames;
    178     filenames.push_back(filename);
    179     return new STTableReader<T, R>(filenames);
    180   }
    181 
    182   static STTableReader<T, R> *Open(const vector<string> &filenames) {
    183     return new STTableReader<T, R>(filenames);
    184   }
    185 
    186   void Reset() {
    187     if (error_) return;
    188     for (size_t i = 0; i < streams_.size(); ++i)
    189       streams_[i]->seekg(positions_[i].front());
    190     MakeHeap();
    191   }
    192 
    193   bool Find(const string &key) {
    194     if (error_) return false;
    195     for (size_t i = 0; i < streams_.size(); ++i)
    196       LowerBound(i, key);
    197     MakeHeap();
    198     return keys_[current_] == key;
    199   }
    200 
    201   bool Done() const { return error_ || heap_.empty(); }
    202 
    203   void Next() {
    204     if (error_) return;
    205     if (streams_[current_]->tellg() <= positions_[current_].back()) {
    206       ReadType(*(streams_[current_]), &(keys_[current_]));
    207       if (!*streams_[current_]) {
    208         FSTERROR() << "STTableReader: error reading file: "
    209                    << sources_[current_];
    210         error_ = true;
    211         return;
    212       }
    213       push_heap(heap_.begin(), heap_.end(), *compare_);
    214     } else {
    215       heap_.pop_back();
    216     }
    217     if (!heap_.empty())
    218       PopHeap();
    219   }
    220 
    221   const string &GetKey() const {
    222     return keys_[current_];
    223   }
    224 
    225   const EntryType &GetEntry() const {
    226     return *entry_;
    227   }
    228 
    229   bool Error() const { return error_; }
    230 
    231  private:
    232   // Comparison functor used to compare stream IDs in the heap
    233   struct Compare {
    234     Compare(const vector<string> *keys) : keys_(keys) {}
    235 
    236     bool operator()(size_t i, size_t j) const {
    237       return (*keys_)[i] > (*keys_)[j];
    238     };
    239 
    240    private:
    241     const vector<string> *keys_;
    242   };
    243 
    244   // Position the stream with ID 'id' at the position corresponding
    245   // to the lower bound for key 'find_key'
    246   void LowerBound(size_t id, const string &find_key) {
    247     ifstream *strm = streams_[id];
    248     const vector<int64> &positions = positions_[id];
    249     size_t low = 0, high = positions.size() - 1;
    250 
    251     while (low < high) {
    252       size_t mid = (low + high)/2;
    253       strm->seekg(positions[mid]);
    254       string key;
    255       ReadType(*strm, &key);
    256       if (key > find_key) {
    257         high = mid;
    258       } else if (key < find_key) {
    259         low = mid + 1;
    260       } else {
    261         for (size_t i = mid; i > low; --i) {
    262           strm->seekg(positions[i - 1]);
    263           ReadType(*strm, &key);
    264           if (key != find_key) {
    265             strm->seekg(positions[i]);
    266             return;
    267           }
    268         }
    269         strm->seekg(positions[low]);
    270         return;
    271       }
    272     }
    273     strm->seekg(positions[low]);
    274   }
    275 
    276   // Add all streams to the heap
    277   void MakeHeap() {
    278     heap_.clear();
    279     for (size_t i = 0; i < streams_.size(); ++i) {
    280       ReadType(*streams_[i], &(keys_[i]));
    281       if (!*streams_[i]) {
    282         FSTERROR() << "STTableReader: error reading file: " << sources_[i];
    283         error_ = true;
    284         return;
    285       }
    286       heap_.push_back(i);
    287     }
    288     make_heap(heap_.begin(), heap_.end(), *compare_);
    289     PopHeap();
    290   }
    291 
    292   // Position the stream with the lowest key at the top
    293   // of the heap, set 'current_' to the ID of that stream
    294   // and read the current entry from that stream
    295   void PopHeap() {
    296     pop_heap(heap_.begin(), heap_.end(), *compare_);
    297     current_ = heap_.back();
    298     if (entry_)
    299       delete entry_;
    300     entry_ = entry_reader_(*streams_[current_]);
    301     if (!entry_)
    302       error_ = true;
    303     if (!*streams_[current_]) {
    304       FSTERROR() << "STTableReader: error reading entry for key: "
    305                  << keys_[current_] << ", file: " << sources_[current_];
    306       error_ = true;
    307     }
    308   }
    309 
    310 
    311   EntryReader entry_reader_;   // Read functor for 'EntryType'
    312   vector<ifstream*> streams_;  // Input streams
    313   vector<string> sources_;     // and corresponding file names
    314   vector<vector<int64> > positions_;  // Index of positions for each stream
    315   vector<string> keys_;  // Lowest unread key for each stream
    316   vector<int64> heap_;   // Heap containing ID of streams with unread keys
    317   int64 current_;        // Id of current stream to be read
    318   Compare *compare_;     // Functor comparing stream IDs for the heap
    319   mutable EntryType *entry_;  // Pointer to the currently read entry
    320   bool error_;
    321 
    322   DISALLOW_COPY_AND_ASSIGN(STTableReader);
    323 };
    324 
    325 
    326 // String-to-type table header reading function template on the entry header
    327 // type 'H' having a member function:
    328 //   Read(istream &strm, const string &filename);
    329 // Checks that 'filename' is an STTable and call the H::Read() on the last
    330 // entry in the STTable.
    331 template <class H>
    332 bool ReadSTTableHeader(const string &filename, H *header) {
    333   ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
    334   int32 magic_number = 0, file_version = 0;
    335   ReadType(strm, &magic_number);
    336   ReadType(strm, &file_version);
    337   if (magic_number != kSTTableMagicNumber) {
    338     LOG(ERROR) << "ReadSTTableHeader: wrong file type: " << filename;
    339     return false;
    340   }
    341   if (file_version != kSTTableFileVersion) {
    342     LOG(ERROR) << "ReadSTTableHeader: wrong file version: " << filename;
    343     return false;
    344   }
    345   int64 i = -1;
    346   strm.seekg(-static_cast<int>(sizeof(int64)), ios_base::end);
    347   ReadType(strm, &i);  // Read number of entries
    348   if (!strm) {
    349     LOG(ERROR) << "ReadSTTableHeader: error reading file: " << filename;
    350     return false;
    351   }
    352   if (i == 0) return true;  // No entry header to read
    353   strm.seekg(-2 * static_cast<int>(sizeof(int64)), ios_base::end);
    354   ReadType(strm, &i);  // Read position for last entry in file
    355   strm.seekg(i);
    356   string key;
    357   ReadType(strm, &key);
    358   header->Read(strm, filename + ":" + key);
    359   if (!strm) {
    360     LOG(ERROR) << "ReadSTTableHeader: error reading file: " << filename;
    361     return false;
    362   }
    363   return true;
    364 }
    365 
// Declaration only; the definition is not part of this header.
// NOTE(review): presumably returns true iff 'filename' is an STTable file
// (compare the magic-number check in ReadSTTableHeader) -- confirm against
// the definition.
bool IsSTTable(const string &filename);
    367 
    368 }  // namespace fst
    369 
    370 #endif  // FST_EXTENSIONS_FAR_STTABLE_H_
    371