1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_FEATURE_EXTRACTOR_H_ 18 #define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_FEATURE_EXTRACTOR_H_ 19 20 #include <memory> 21 #include <string> 22 #include <vector> 23 24 #include "common/feature-extractor.h" 25 #include "common/task-context.h" 26 #include "common/workspace.h" 27 #include "util/base/logging.h" 28 #include "util/base/macros.h" 29 30 namespace libtextclassifier { 31 namespace nlp_core { 32 33 // An EmbeddingFeatureExtractor manages the extraction of features for 34 // embedding-based models. It wraps a sequence of underlying classes of feature 35 // extractors, along with associated predicate maps. Each class of feature 36 // extractors is associated with a name, e.g., "words", "labels", "tags". 37 // 38 // The class is split between a generic abstract version, 39 // GenericEmbeddingFeatureExtractor (that can be initialized without knowing the 40 // signature of the ExtractFeatures method) and a typed version. 41 // 42 // The predicate maps must be initialized before use: they can be loaded using 43 // Read() or updated via UpdateMapsForExample. 44 class GenericEmbeddingFeatureExtractor { 45 public: 46 GenericEmbeddingFeatureExtractor() {} 47 virtual ~GenericEmbeddingFeatureExtractor() {} 48 49 // Get the prefix std::string to put in front of all arguments, so they don't 50 // conflict with other embedding models. 51 virtual const std::string ArgPrefix() const = 0; 52 53 // Initializes predicate maps and embedding space names that are common for 54 // all embedding-based feature extractors. 55 virtual bool Init(TaskContext *context); 56 57 // Requests workspace for the underlying feature extractors. This is 58 // implemented in the typed class. 59 virtual void RequestWorkspaces(WorkspaceRegistry *registry) = 0; 60 61 // Returns number of embedding spaces. 62 int NumEmbeddings() const { return embedding_dims_.size(); } 63 64 // Number of predicates for the embedding at a given index (vocabulary size). 65 // Returns -1 if index is out of bounds. 66 int EmbeddingSize(int index) const { 67 const GenericFeatureExtractor *extractor = generic_feature_extractor(index); 68 return (extractor == nullptr) ? -1 : extractor->GetDomainSize(); 69 } 70 71 // Returns the dimensionality of the embedding space. 72 int EmbeddingDims(int index) const { return embedding_dims_[index]; } 73 74 // Accessor for embedding dims (dimensions of the embedding spaces). 75 const std::vector<int> &embedding_dims() const { return embedding_dims_; } 76 77 const std::vector<std::string> &embedding_fml() const { 78 return embedding_fml_; 79 } 80 81 // Get parameter name by concatenating the prefix and the original name. 82 std::string GetParamName(const std::string ¶m_name) const { 83 std::string full_name = ArgPrefix(); 84 full_name.push_back('_'); 85 full_name.append(param_name); 86 return full_name; 87 } 88 89 protected: 90 // Provides the generic class with access to the templated extractors. This is 91 // used to get the type information out of the feature extractor without 92 // knowing the specific calling arguments of the extractor itself. 93 // Returns nullptr for an out-of-bounds idx. 94 virtual const GenericFeatureExtractor *generic_feature_extractor( 95 int idx) const = 0; 96 97 private: 98 // Embedding space names for parameter sharing. 99 std::vector<std::string> embedding_names_; 100 101 // FML strings for each feature extractor. 102 std::vector<std::string> embedding_fml_; 103 104 // Size of each of the embedding spaces (maximum predicate id). 105 std::vector<int> embedding_sizes_; 106 107 // Embedding dimensions of the embedding spaces (i.e. 32, 64 etc.) 108 std::vector<int> embedding_dims_; 109 110 TC_DISALLOW_COPY_AND_ASSIGN(GenericEmbeddingFeatureExtractor); 111 }; 112 113 // Templated, object-specific implementation of the 114 // EmbeddingFeatureExtractor. EXTRACTOR should be a FeatureExtractor<OBJ, 115 // ARGS...> class that has the appropriate FeatureTraits() to ensure that 116 // locator type features work. 117 // 118 // Note: for backwards compatibility purposes, this always reads the FML spec 119 // from "<prefix>_features". 120 template <class EXTRACTOR, class OBJ, class... ARGS> 121 class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor { 122 public: 123 // Initializes all predicate maps, feature extractors, etc. 124 bool Init(TaskContext *context) override { 125 if (!GenericEmbeddingFeatureExtractor::Init(context)) { 126 return false; 127 } 128 feature_extractors_.resize(embedding_fml().size()); 129 for (int i = 0; i < embedding_fml().size(); ++i) { 130 feature_extractors_[i].reset(new EXTRACTOR()); 131 if (!feature_extractors_[i]->Parse(embedding_fml()[i])) { 132 return false; 133 } 134 if (!feature_extractors_[i]->Setup(context)) { 135 return false; 136 } 137 } 138 for (auto &feature_extractor : feature_extractors_) { 139 if (!feature_extractor->Init(context)) { 140 return false; 141 } 142 } 143 return true; 144 } 145 146 // Requests workspaces from the registry. Must be called after Init(), and 147 // before Preprocess(). 148 void RequestWorkspaces(WorkspaceRegistry *registry) override { 149 for (auto &feature_extractor : feature_extractors_) { 150 feature_extractor->RequestWorkspaces(registry); 151 } 152 } 153 154 // Must be called on the object one state for each sentence, before any 155 // feature extraction (e.g., UpdateMapsForExample, ExtractFeatures). 156 void Preprocess(WorkspaceSet *workspaces, OBJ *obj) const { 157 for (auto &feature_extractor : feature_extractors_) { 158 feature_extractor->Preprocess(workspaces, obj); 159 } 160 } 161 162 // Extracts features using the extractors. Note that features must already 163 // be initialized to the correct number of feature extractors. No predicate 164 // mapping is applied. 165 void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &obj, 166 ARGS... args, 167 std::vector<FeatureVector> *features) const { 168 TC_DCHECK(features != nullptr); 169 TC_DCHECK_EQ(features->size(), feature_extractors_.size()); 170 for (int i = 0; i < feature_extractors_.size(); ++i) { 171 (*features)[i].clear(); 172 feature_extractors_[i]->ExtractFeatures(workspaces, obj, args..., 173 &(*features)[i]); 174 } 175 } 176 177 protected: 178 // Provides generic access to the feature extractors. 179 const GenericFeatureExtractor *generic_feature_extractor( 180 int idx) const override { 181 if ((idx < 0) || (idx >= feature_extractors_.size())) { 182 TC_LOG(ERROR) << "Out of bounds index " << idx; 183 TC_DCHECK(false); // Crash in debug mode. 184 return nullptr; 185 } 186 return feature_extractors_[idx].get(); 187 } 188 189 private: 190 // Templated feature extractor class. 191 std::vector<std::unique_ptr<EXTRACTOR>> feature_extractors_; 192 }; 193 194 } // namespace nlp_core 195 } // namespace libtextclassifier 196 197 #endif // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_FEATURE_EXTRACTOR_H_ 198