Home | History | Annotate | Download | only in common
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_FEATURE_EXTRACTOR_H_
     18 #define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_FEATURE_EXTRACTOR_H_
     19 
     20 #include <memory>
     21 #include <string>
     22 #include <vector>
     23 
     24 #include "common/feature-extractor.h"
     25 #include "common/task-context.h"
     26 #include "common/workspace.h"
     27 #include "util/base/logging.h"
     28 #include "util/base/macros.h"
     29 
     30 namespace libtextclassifier {
     31 namespace nlp_core {
     32 
     33 // An EmbeddingFeatureExtractor manages the extraction of features for
     34 // embedding-based models. It wraps a sequence of underlying classes of feature
     35 // extractors, along with associated predicate maps. Each class of feature
     36 // extractors is associated with a name, e.g., "words", "labels", "tags".
     37 //
     38 // The class is split between a generic abstract version,
     39 // GenericEmbeddingFeatureExtractor (that can be initialized without knowing the
     40 // signature of the ExtractFeatures method) and a typed version.
     41 //
     42 // The predicate maps must be initialized before use: they can be loaded using
     43 // Read() or updated via UpdateMapsForExample.
     44 class GenericEmbeddingFeatureExtractor {
     45  public:
     46   GenericEmbeddingFeatureExtractor() {}
     47   virtual ~GenericEmbeddingFeatureExtractor() {}
     48 
     49   // Get the prefix std::string to put in front of all arguments, so they don't
     50   // conflict with other embedding models.
     51   virtual const std::string ArgPrefix() const = 0;
     52 
     53   // Initializes predicate maps and embedding space names that are common for
     54   // all embedding-based feature extractors.
     55   virtual bool Init(TaskContext *context);
     56 
     57   // Requests workspace for the underlying feature extractors. This is
     58   // implemented in the typed class.
     59   virtual void RequestWorkspaces(WorkspaceRegistry *registry) = 0;
     60 
     61   // Returns number of embedding spaces.
     62   int NumEmbeddings() const { return embedding_dims_.size(); }
     63 
     64   // Number of predicates for the embedding at a given index (vocabulary size).
     65   // Returns -1 if index is out of bounds.
     66   int EmbeddingSize(int index) const {
     67     const GenericFeatureExtractor *extractor = generic_feature_extractor(index);
     68     return (extractor == nullptr) ? -1 : extractor->GetDomainSize();
     69   }
     70 
     71   // Returns the dimensionality of the embedding space.
     72   int EmbeddingDims(int index) const { return embedding_dims_[index]; }
     73 
     74   // Accessor for embedding dims (dimensions of the embedding spaces).
     75   const std::vector<int> &embedding_dims() const { return embedding_dims_; }
     76 
     77   const std::vector<std::string> &embedding_fml() const {
     78     return embedding_fml_;
     79   }
     80 
     81   // Get parameter name by concatenating the prefix and the original name.
     82   std::string GetParamName(const std::string &param_name) const {
     83     std::string full_name = ArgPrefix();
     84     full_name.push_back('_');
     85     full_name.append(param_name);
     86     return full_name;
     87   }
     88 
     89  protected:
     90   // Provides the generic class with access to the templated extractors. This is
     91   // used to get the type information out of the feature extractor without
     92   // knowing the specific calling arguments of the extractor itself.
     93   // Returns nullptr for an out-of-bounds idx.
     94   virtual const GenericFeatureExtractor *generic_feature_extractor(
     95       int idx) const = 0;
     96 
     97  private:
     98   // Embedding space names for parameter sharing.
     99   std::vector<std::string> embedding_names_;
    100 
    101   // FML strings for each feature extractor.
    102   std::vector<std::string> embedding_fml_;
    103 
    104   // Size of each of the embedding spaces (maximum predicate id).
    105   std::vector<int> embedding_sizes_;
    106 
    107   // Embedding dimensions of the embedding spaces (i.e. 32, 64 etc.)
    108   std::vector<int> embedding_dims_;
    109 
    110   TC_DISALLOW_COPY_AND_ASSIGN(GenericEmbeddingFeatureExtractor);
    111 };
    112 
    113 // Templated, object-specific implementation of the
    114 // EmbeddingFeatureExtractor. EXTRACTOR should be a FeatureExtractor<OBJ,
    115 // ARGS...> class that has the appropriate FeatureTraits() to ensure that
    116 // locator type features work.
    117 //
    118 // Note: for backwards compatibility purposes, this always reads the FML spec
    119 // from "<prefix>_features".
    120 template <class EXTRACTOR, class OBJ, class... ARGS>
    121 class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor {
    122  public:
    123   // Initializes all predicate maps, feature extractors, etc.
    124   bool Init(TaskContext *context) override {
    125     if (!GenericEmbeddingFeatureExtractor::Init(context)) {
    126       return false;
    127     }
    128     feature_extractors_.resize(embedding_fml().size());
    129     for (int i = 0; i < embedding_fml().size(); ++i) {
    130       feature_extractors_[i].reset(new EXTRACTOR());
    131       if (!feature_extractors_[i]->Parse(embedding_fml()[i])) {
    132         return false;
    133       }
    134       if (!feature_extractors_[i]->Setup(context)) {
    135         return false;
    136       }
    137     }
    138     for (auto &feature_extractor : feature_extractors_) {
    139       if (!feature_extractor->Init(context)) {
    140         return false;
    141       }
    142     }
    143     return true;
    144   }
    145 
    146   // Requests workspaces from the registry. Must be called after Init(), and
    147   // before Preprocess().
    148   void RequestWorkspaces(WorkspaceRegistry *registry) override {
    149     for (auto &feature_extractor : feature_extractors_) {
    150       feature_extractor->RequestWorkspaces(registry);
    151     }
    152   }
    153 
    154   // Must be called on the object one state for each sentence, before any
    155   // feature extraction (e.g., UpdateMapsForExample, ExtractFeatures).
    156   void Preprocess(WorkspaceSet *workspaces, OBJ *obj) const {
    157     for (auto &feature_extractor : feature_extractors_) {
    158       feature_extractor->Preprocess(workspaces, obj);
    159     }
    160   }
    161 
    162   // Extracts features using the extractors. Note that features must already
    163   // be initialized to the correct number of feature extractors. No predicate
    164   // mapping is applied.
    165   void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &obj,
    166                        ARGS... args,
    167                        std::vector<FeatureVector> *features) const {
    168     TC_DCHECK(features != nullptr);
    169     TC_DCHECK_EQ(features->size(), feature_extractors_.size());
    170     for (int i = 0; i < feature_extractors_.size(); ++i) {
    171       (*features)[i].clear();
    172       feature_extractors_[i]->ExtractFeatures(workspaces, obj, args...,
    173                                               &(*features)[i]);
    174     }
    175   }
    176 
    177  protected:
    178   // Provides generic access to the feature extractors.
    179   const GenericFeatureExtractor *generic_feature_extractor(
    180       int idx) const override {
    181     if ((idx < 0) || (idx >= feature_extractors_.size())) {
    182       TC_LOG(ERROR) << "Out of bounds index " << idx;
    183       TC_DCHECK(false);  // Crash in debug mode.
    184       return nullptr;
    185     }
    186     return feature_extractors_[idx].get();
    187   }
    188 
    189  private:
    190   // Templated feature extractor class.
    191   std::vector<std::unique_ptr<EXTRACTOR>> feature_extractors_;
    192 };
    193 
    194 }  // namespace nlp_core
    195 }  // namespace libtextclassifier
    196 
    197 #endif  // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_FEATURE_EXTRACTOR_H_
    198