Home | History | Annotate | Download | only in far
      1 // extract-main.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 // Modified: jpr (at) google.com (Jake Ratkiewicz) to use the new arc-dispatch
     18 
     19 // \file
// Extracts component FSTs from a finite-state archive.
     21 //
     22 
     23 #ifndef FST_EXTENSIONS_FAR_EXTRACT_H__
     24 #define FST_EXTENSIONS_FAR_EXTRACT_H__
     25 
     26 #include <string>
     27 #include <vector>
     28 using std::vector;
     29 
     30 #include <fst/extensions/far/far.h>
     31 
     32 namespace fst {
     33 
     34 template<class Arc>
     35 inline void FarWriteFst(const Fst<Arc>* fst, string key,
     36                         string* okey, int* nrep,
     37                         const int32 &generate_filenames, int i,
     38                         const string &filename_prefix,
     39                         const string &filename_suffix) {
     40   if (key == *okey)
     41     ++*nrep;
     42   else
     43     *nrep = 0;
     44 
     45   *okey = key;
     46 
     47   string ofilename;
     48   if (generate_filenames) {
     49     ostringstream tmp;
     50     tmp.width(generate_filenames);
     51     tmp.fill('0');
     52     tmp << i;
     53     ofilename = tmp.str();
     54   } else {
     55     if (*nrep > 0) {
     56       ostringstream tmp;
     57       tmp << '.' << nrep;
     58       key.append(tmp.str().data(), tmp.str().size());
     59     }
     60     ofilename = key;
     61   }
     62   fst->Write(filename_prefix + ofilename + filename_suffix);
     63 }
     64 
     65 template<class Arc>
     66 void FarExtract(const vector<string> &ifilenames,
     67                 const int32 &generate_filenames,
     68                 const string &keys,
     69                 const string &key_separator,
     70                 const string &range_delimiter,
     71                 const string &filename_prefix,
     72                 const string &filename_suffix) {
     73   FarReader<Arc> *far_reader = FarReader<Arc>::Open(ifilenames);
     74   if (!far_reader) return;
     75 
     76   string okey;
     77   int nrep = 0;
     78 
     79   vector<char *> key_vector;
     80   // User has specified a set of fsts to extract, where some of the "fsts" could
     81   // be ranges.
     82   if (!keys.empty()) {
     83     char *keys_cstr = new char[keys.size()+1];
     84     strcpy(keys_cstr, keys.c_str());
     85     SplitToVector(keys_cstr, key_separator.c_str(), &key_vector, true);
     86     int i = 0;
     87     for (int k = 0; k < key_vector.size(); ++k, ++i) {
     88       string key = string(key_vector[k]);
     89       char *key_cstr = new char[key.size()+1];
     90       strcpy(key_cstr, key.c_str());
     91       vector<char *> range_vector;
     92       SplitToVector(key_cstr, range_delimiter.c_str(), &range_vector, false);
     93       if (range_vector.size() == 1) {  // Not a range
     94         if (!far_reader->Find(key)) {
     95           LOG(ERROR) << "FarExtract: Cannot find key: " << key;
     96           return;
     97         }
     98         const Fst<Arc> &fst = far_reader->GetFst();
     99         FarWriteFst(&fst, key, &okey, &nrep, generate_filenames, i,
    100                     filename_prefix, filename_suffix);
    101       } else if (range_vector.size() == 2) {  // A legal range
    102         string begin_key = string(range_vector[0]);
    103         string end_key = string(range_vector[1]);
    104         if (begin_key.empty() || end_key.empty()) {
    105           LOG(ERROR) << "FarExtract: Illegal range specification: " << key;
    106           return;
    107         }
    108         if (!far_reader->Find(begin_key)) {
    109           LOG(ERROR) << "FarExtract: Cannot find key: " << begin_key;
    110           return;
    111         }
    112         for ( ; !far_reader->Done(); far_reader->Next(), ++i) {
    113           string ikey = far_reader->GetKey();
    114           if (end_key < ikey) break;
    115           const Fst<Arc> &fst = far_reader->GetFst();
    116           FarWriteFst(&fst, ikey, &okey, &nrep, generate_filenames, i,
    117                       filename_prefix, filename_suffix);
    118         }
    119       } else {
    120         LOG(ERROR) << "FarExtract: Illegal range specification: " << key;
    121         return;
    122       }
    123       delete key_cstr;
    124     }
    125     delete keys_cstr;
    126     return;
    127   }
    128   // Nothing specified: extract everything.
    129   for (int i = 1; !far_reader->Done(); far_reader->Next(), ++i) {
    130     string key = far_reader->GetKey();
    131     const Fst<Arc> &fst = far_reader->GetFst();
    132     FarWriteFst(&fst, key, &okey, &nrep, generate_filenames, i,
    133                 filename_prefix, filename_suffix);
    134   }
    135   return;
    136 }
    137 
    138 }  // namespace fst
    139 
    140 #endif  // FST_EXTENSIONS_FAR_EXTRACT_H__
    141