Home | History | Annotate | Download | only in common
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_
     18 #define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_
     19 
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

#include "lang_id/common/fel/feature-extractor.h"
#include "lang_id/common/fel/task-context.h"
#include "lang_id/common/fel/workspace.h"
#include "lang_id/common/lite_base/attributes.h"
     28 
     29 namespace libtextclassifier3 {
     30 namespace mobile {
     31 
     32 // An EmbeddingFeatureExtractor manages the extraction of features for
     33 // embedding-based models. It wraps a sequence of underlying classes of feature
     34 // extractors, along with associated predicate maps. Each class of feature
     35 // extractors is associated with a name, e.g., "words", "labels", "tags".
     36 //
     37 // The class is split between a generic abstract version,
     38 // GenericEmbeddingFeatureExtractor (that can be initialized without knowing the
     39 // signature of the ExtractFeatures method) and a typed version.
     40 //
     41 // The predicate maps must be initialized before use: they can be loaded using
     42 // Read() or updated via UpdateMapsForExample.
     43 class GenericEmbeddingFeatureExtractor {
     44  public:
     45   // Constructs this GenericEmbeddingFeatureExtractor.
     46   //
     47   // |arg_prefix| is a string prefix for the relevant TaskContext parameters, to
     48   // avoid name clashes.  See GetParamName().
     49   explicit GenericEmbeddingFeatureExtractor(const string &arg_prefix)
     50       : arg_prefix_(arg_prefix) {}
     51 
     52   virtual ~GenericEmbeddingFeatureExtractor() {}
     53 
     54   // Sets/inits up predicate maps and embedding space names that are common for
     55   // all embedding based feature extractors.
     56   //
     57   // Returns true on success, false otherwise.
     58   SAFTM_MUST_USE_RESULT virtual bool Setup(TaskContext *context);
     59   SAFTM_MUST_USE_RESULT virtual bool Init(TaskContext *context);
     60 
     61   // Requests workspace for the underlying feature extractors. This is
     62   // implemented in the typed class.
     63   virtual void RequestWorkspaces(WorkspaceRegistry *registry) = 0;
     64 
     65   // Returns number of embedding spaces.
     66   int NumEmbeddings() const { return embedding_dims_.size(); }
     67 
     68   const std::vector<string> &embedding_fml() const { return embedding_fml_; }
     69 
     70   // Get parameter name by concatenating the prefix and the original name.
     71   string GetParamName(const string &param_name) const {
     72     string full_name = arg_prefix_;
     73     full_name.push_back('_');
     74     full_name.append(param_name);
     75     return full_name;
     76   }
     77 
     78  private:
     79   // Prefix for TaskContext parameters.
     80   const string arg_prefix_;
     81 
     82   // Embedding space names for parameter sharing.
     83   std::vector<string> embedding_names_;
     84 
     85   // FML strings for each feature extractor.
     86   std::vector<string> embedding_fml_;
     87 
     88   // Size of each of the embedding spaces (maximum predicate id).
     89   std::vector<int> embedding_sizes_;
     90 
     91   // Embedding dimensions of the embedding spaces (i.e. 32, 64 etc.)
     92   std::vector<int> embedding_dims_;
     93 };
     94 
     95 // Templated, object-specific implementation of the
     96 // EmbeddingFeatureExtractor. EXTRACTOR should be a FeatureExtractor<OBJ,
     97 // ARGS...> class that has the appropriate FeatureTraits() to ensure that
     98 // locator type features work.
     99 //
    100 // Note: for backwards compatibility purposes, this always reads the FML spec
    101 // from "<prefix>_features".
    102 template <class EXTRACTOR, class OBJ, class... ARGS>
    103 class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor {
    104  public:
    105   // Constructs this EmbeddingFeatureExtractor.
    106   //
    107   // |arg_prefix| is a string prefix for the relevant TaskContext parameters, to
    108   // avoid name clashes.  See GetParamName().
    109   explicit EmbeddingFeatureExtractor(const string &arg_prefix)
    110       : GenericEmbeddingFeatureExtractor(arg_prefix) {}
    111 
    112   // Sets up all predicate maps, feature extractors, and flags.
    113   SAFTM_MUST_USE_RESULT bool Setup(TaskContext *context) override {
    114     if (!GenericEmbeddingFeatureExtractor::Setup(context)) {
    115       return false;
    116     }
    117     feature_extractors_.resize(embedding_fml().size());
    118     for (int i = 0; i < embedding_fml().size(); ++i) {
    119       feature_extractors_[i].reset(new EXTRACTOR());
    120       if (!feature_extractors_[i]->Parse(embedding_fml()[i])) return false;
    121       if (!feature_extractors_[i]->Setup(context)) return false;
    122     }
    123     return true;
    124   }
    125 
    126   // Initializes resources needed by the feature extractors.
    127   SAFTM_MUST_USE_RESULT bool Init(TaskContext *context) override {
    128     if (!GenericEmbeddingFeatureExtractor::Init(context)) return false;
    129     for (auto &feature_extractor : feature_extractors_) {
    130       if (!feature_extractor->Init(context)) return false;
    131     }
    132     return true;
    133   }
    134 
    135   // Requests workspaces from the registry. Must be called after Init(), and
    136   // before Preprocess().
    137   void RequestWorkspaces(WorkspaceRegistry *registry) override {
    138     for (auto &feature_extractor : feature_extractors_) {
    139       feature_extractor->RequestWorkspaces(registry);
    140     }
    141   }
    142 
    143   // Must be called on the object one state for each sentence, before any
    144   // feature extraction (e.g., UpdateMapsForExample, ExtractFeatures).
    145   void Preprocess(WorkspaceSet *workspaces, OBJ *obj) const {
    146     for (auto &feature_extractor : feature_extractors_) {
    147       feature_extractor->Preprocess(workspaces, obj);
    148     }
    149   }
    150 
    151   // Extracts features using the extractors. Note that features must already
    152   // be initialized to the correct number of feature extractors. No predicate
    153   // mapping is applied.
    154   void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &obj,
    155                        ARGS... args,
    156                        std::vector<FeatureVector> *features) const {
    157     // DCHECK(features != nullptr);
    158     // DCHECK_EQ(features->size(), feature_extractors_.size());
    159     for (int i = 0; i < feature_extractors_.size(); ++i) {
    160       (*features)[i].clear();
    161       feature_extractors_[i]->ExtractFeatures(workspaces, obj, args...,
    162                                               &(*features)[i]);
    163     }
    164   }
    165 
    166  private:
    167   // Templated feature extractor class.
    168   std::vector<std::unique_ptr<EXTRACTOR>> feature_extractors_;
    169 };
    170 
    171 }  // namespace mobile
}  // namespace libtextclassifier3
    173 
    174 #endif  // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_
    175