Home | History | Annotate | Download | only in common
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 // Generic feature extractor for extracting features from objects. The feature
     18 // extractor can be used for extracting features from any object. The feature
     19 // extractor and feature function classes are template classes that have to
// be instantiated for extracting features from a specific object type.
     21 //
     22 // A feature extractor consists of a hierarchy of feature functions. Each
     23 // feature function extracts one or more feature type and value pairs from the
     24 // object.
     25 //
     26 // The feature extractor has a modular design where new feature functions can be
     27 // registered as components. The feature extractor is initialized from a
     28 // descriptor represented by a protocol buffer. The feature extractor can also
     29 // be initialized from a text-based source specification of the feature
     30 // extractor. Feature specification parsers can be added as components. By
     31 // default the feature extractor can be read from an ASCII protocol buffer or in
     32 // a simple feature modeling language (fml).
     33 
// A feature function is invoked with a focus. Nested feature functions can be
// invoked with another focus determined by the parent feature function.
     36 
     37 #ifndef LIBTEXTCLASSIFIER_COMMON_FEATURE_EXTRACTOR_H_
     38 #define LIBTEXTCLASSIFIER_COMMON_FEATURE_EXTRACTOR_H_
     39 
     40 #include <stddef.h>
     41 
     42 #include <string>
     43 #include <vector>
     44 
     45 #include "common/feature-descriptors.h"
     46 #include "common/feature-types.h"
     47 #include "common/fml-parser.h"
     48 #include "common/registry.h"
     49 #include "common/task-context.h"
     50 #include "common/workspace.h"
     51 #include "util/base/integral_types.h"
     52 #include "util/base/logging.h"
     53 #include "util/base/macros.h"
     54 #include "util/gtl/stl_util.h"
     55 
     56 namespace libtextclassifier {
     57 namespace nlp_core {
     58 
     59 typedef int64 Predicate;
     60 typedef Predicate FeatureValue;
     61 
     62 // A union used to represent discrete and continuous feature values.
     63 union FloatFeatureValue {
     64  public:
     65   explicit FloatFeatureValue(FeatureValue v) : discrete_value(v) {}
     66   FloatFeatureValue(uint32 i, float w) : id(i), weight(w) {}
     67   FeatureValue discrete_value;
     68   struct {
     69     uint32 id;
     70     float weight;
     71   };
     72 };
     73 
     74 // A feature vector contains feature type and value pairs.
     75 class FeatureVector {
     76  public:
     77   FeatureVector() {}
     78 
     79   // Adds feature type and value pair to feature vector.
     80   void add(FeatureType *type, FeatureValue value) {
     81     features_.emplace_back(type, value);
     82   }
     83 
     84   // Removes all elements from the feature vector.
     85   void clear() { features_.clear(); }
     86 
     87   // Returns the number of elements in the feature vector.
     88   int size() const { return features_.size(); }
     89 
     90   // Reserves space in the underlying feature vector.
     91   void reserve(int n) { features_.reserve(n); }
     92 
     93   // Returns feature type for an element in the feature vector.
     94   FeatureType *type(int index) const { return features_[index].type; }
     95 
     96   // Returns feature value for an element in the feature vector.
     97   FeatureValue value(int index) const { return features_[index].value; }
     98 
     99  private:
    100   // Structure for holding feature type and value pairs.
    101   struct Element {
    102     Element() : type(nullptr), value(-1) {}
    103     Element(FeatureType *t, FeatureValue v) : type(t), value(v) {}
    104 
    105     FeatureType *type;
    106     FeatureValue value;
    107   };
    108 
    109   // Array for storing feature vector elements.
    110   std::vector<Element> features_;
    111 
    112   TC_DISALLOW_COPY_AND_ASSIGN(FeatureVector);
    113 };
    114 
    115 // The generic feature extractor is the type-independent part of a feature
    116 // extractor. This holds the descriptor for the feature extractor and the
    117 // collection of feature types used in the feature extractor.  The feature
    118 // types are not available until FeatureExtractor<>::Init() has been called.
    119 class GenericFeatureExtractor {
    120  public:
    121   GenericFeatureExtractor();
    122   virtual ~GenericFeatureExtractor();
    123 
    124   // Initializes the feature extractor from an FML string specification.  For
    125   // the FML specification grammar, see fml-parser.h.
    126   //
    127   // Returns true on success, false on syntax error.
    128   bool Parse(const std::string &source);
    129 
    130   // Returns the feature extractor descriptor.
    131   const FeatureExtractorDescriptor &descriptor() const { return descriptor_; }
    132   FeatureExtractorDescriptor *mutable_descriptor() { return &descriptor_; }
    133 
    134   // Returns the number of feature types in the feature extractor.  Invalid
    135   // before Init() has been called.
    136   int feature_types() const { return feature_types_.size(); }
    137 
    138   // Returns a feature type used in the extractor.  Invalid before Init() has
    139   // been called.
    140   const FeatureType *feature_type(int index) const {
    141     return feature_types_[index];
    142   }
    143 
    144   // Returns the feature domain size of this feature extractor.
    145   // NOTE: The way that domain size is calculated is, for some, unintuitive. It
    146   // is the largest domain size of any feature type.
    147   FeatureValue GetDomainSize() const;
    148 
    149  protected:
    150   // Initializes the feature types used by the extractor.  Called from
    151   // FeatureExtractor<>::Init().
    152   //
    153   // Returns true on success, false on error.
    154   bool InitializeFeatureTypes();
    155 
    156  private:
    157   // Initializes the top-level feature functions.
    158   virtual bool InitializeFeatureFunctions() = 0;
    159 
    160   // Returns all feature types used by the extractor. The feature types are
    161   // added to the result array.
    162   virtual void GetFeatureTypes(std::vector<FeatureType *> *types) const = 0;
    163 
    164   // Descriptor for the feature extractor. This is a protocol buffer that
    165   // contains all the information about the feature extractor. The feature
    166   // functions are initialized from the information in the descriptor.
    167   FeatureExtractorDescriptor descriptor_;
    168 
    169   // All feature types used by the feature extractor. The collection of all the
    170   // feature types describes the feature space of the feature set produced by
    171   // the feature extractor.  Not owned.
    172   std::vector<FeatureType *> feature_types_;
    173 
    174   TC_DISALLOW_COPY_AND_ASSIGN(GenericFeatureExtractor);
    175 };
    176 
    177 // The generic feature function is the type-independent part of a feature
    178 // function. Each feature function is associated with the descriptor that it is
    179 // instantiated from.  The feature types associated with this feature function
    180 // will be established by the time FeatureExtractor<>::Init() completes.
    181 class GenericFeatureFunction {
    182  public:
    183   // A feature value that represents the absence of a value.
    184   static constexpr FeatureValue kNone = -1;
    185 
    186   GenericFeatureFunction();
    187   virtual ~GenericFeatureFunction();
    188 
    189   // Sets up the feature function. NB: FeatureTypes of nested functions are not
    190   // guaranteed to be available until Init().
    191   //
    192   // Returns true on success, false on error.
    193   virtual bool Setup(TaskContext *context) { return true; }
    194 
    195   // Initializes the feature function. NB: The FeatureType of this function must
    196   // be established when this method completes.
    197   //
    198   // Returns true on success, false on error.
    199   virtual bool Init(TaskContext *context) { return true; }
    200 
    201   // Requests workspaces from a registry to obtain indices into a WorkspaceSet
    202   // for any Workspace objects used by this feature function. NB: This will be
    203   // called after Init(), so it can depend on resources and arguments.
    204   virtual void RequestWorkspaces(WorkspaceRegistry *registry) {}
    205 
    206   // Appends the feature types produced by the feature function to types.  The
    207   // default implementation appends feature_type(), if non-null.  Invalid
    208   // before Init() has been called.
    209   virtual void GetFeatureTypes(std::vector<FeatureType *> *types) const;
    210 
    211   // Returns the feature type for feature produced by this feature function. If
    212   // the feature function produces features of different types this returns
    213   // null.  Invalid before Init() has been called.
    214   virtual FeatureType *GetFeatureType() const;
    215 
    216   // Returns the name of the registry used for creating the feature function.
    217   // This can be used for checking if two feature functions are of the same
    218   // kind.
    219   virtual const char *RegistryName() const = 0;
    220 
    221   // Returns the value of a named parameter from the feature function
    222   // descriptor.  Returns empty string ("") if parameter is not found.
    223   std::string GetParameter(const std::string &name) const;
    224 
    225   // Returns the int value of a named parameter from the feature function
    226   // descriptor.  Returns default_value if the parameter is not found or if its
    227   // value can't be parsed as an int.
    228   int GetIntParameter(const std::string &name, int default_value) const;
    229 
    230   // Returns the bool value of a named parameter from the feature function
    231   // descriptor.  Returns default_value if the parameter is not found or if its
    232   // value is not "true" or "false".
    233   bool GetBoolParameter(const std::string &name, bool default_value) const;
    234 
    235   // Returns the FML function description for the feature function, i.e. the
    236   // name and parameters without the nested features.
    237   std::string FunctionName() const {
    238     std::string output;
    239     ToFMLFunction(*descriptor_, &output);
    240     return output;
    241   }
    242 
    243   // Returns the prefix for nested feature functions. This is the prefix of this
    244   // feature function concatenated with the feature function name.
    245   std::string SubPrefix() const {
    246     return prefix_.empty() ? FunctionName() : prefix_ + "." + FunctionName();
    247   }
    248 
    249   // Returns/sets the feature extractor this function belongs to.
    250   GenericFeatureExtractor *extractor() const { return extractor_; }
    251   void set_extractor(GenericFeatureExtractor *extractor) {
    252     extractor_ = extractor;
    253   }
    254 
    255   // Returns/sets the feature function descriptor.
    256   FeatureFunctionDescriptor *descriptor() const { return descriptor_; }
    257   void set_descriptor(FeatureFunctionDescriptor *descriptor) {
    258     descriptor_ = descriptor;
    259   }
    260 
    261   // Returns a descriptive name for the feature function. The name is taken from
    262   // the descriptor for the feature function. If the name is empty or the
    263   // feature function is a variable the name is the FML representation of the
    264   // feature, including the prefix.
    265   std::string name() const;
    266 
    267   // Returns the argument from the feature function descriptor. It defaults to
    268   // 0 if the argument has not been specified.
    269   int argument() const {
    270     return descriptor_->has_argument() ? descriptor_->argument() : 0;
    271   }
    272 
    273   // Returns/sets/clears function name prefix.
    274   const std::string &prefix() const { return prefix_; }
    275   void set_prefix(const std::string &prefix) { prefix_ = prefix; }
    276 
    277  protected:
    278   // Returns the feature type for single-type feature functions.
    279   FeatureType *feature_type() const { return feature_type_; }
    280 
    281   // Sets the feature type for single-type feature functions.  This takes
    282   // ownership of feature_type.  Can only be called once with a non-null
    283   // pointer.
    284   void set_feature_type(FeatureType *feature_type) {
    285     TC_DCHECK_NE(feature_type, nullptr);
    286     feature_type_ = feature_type;
    287   }
    288 
    289  private:
    290   // Feature extractor this feature function belongs to.  Not owned.
    291   GenericFeatureExtractor *extractor_ = nullptr;
    292 
    293   // Descriptor for feature function.  Not owned.
    294   FeatureFunctionDescriptor *descriptor_ = nullptr;
    295 
    296   // Feature type for features produced by this feature function. If the
    297   // feature function produces features of multiple feature types this is null
    298   // and the feature function must return it's feature types in
    299   // GetFeatureTypes().  Owned.
    300   FeatureType *feature_type_ = nullptr;
    301 
    302   // Prefix used for sub-feature types of this function.
    303   std::string prefix_;
    304 };
    305 
    306 // Feature function that can extract features from an object.  Templated on
    307 // two type arguments:
    308 //
    309 // OBJ:  The "object" from which features are extracted; e.g., a sentence.  This
    310 //       should be a plain type, rather than a reference or pointer.
    311 //
    312 // ARGS: A set of 0 or more types that are used to "index" into some part of the
    313 //       object that should be extracted, e.g. an int token index for a sentence
    314 //       object.  This should not be a reference type.
    315 template <class OBJ, class... ARGS>
    316 class FeatureFunction
    317     : public GenericFeatureFunction,
    318       public RegisterableClass<FeatureFunction<OBJ, ARGS...> > {
    319  public:
    320   using Self = FeatureFunction<OBJ, ARGS...>;
    321 
    322   // Preprocesses the object.  This will be called prior to calling Evaluate()
    323   // or Compute() on that object.
    324   virtual void Preprocess(WorkspaceSet *workspaces, OBJ *object) const {}
    325 
    326   // Appends features computed from the object and focus to the result.  The
    327   // default implementation delegates to Compute(), adding a single value if
    328   // available.  Multi-valued feature functions must override this method.
    329   virtual void Evaluate(const WorkspaceSet &workspaces, const OBJ &object,
    330                         ARGS... args, FeatureVector *result) const {
    331     FeatureValue value = Compute(workspaces, object, args..., result);
    332     if (value != kNone) result->add(feature_type(), value);
    333   }
    334 
    335   // Returns a feature value computed from the object and focus, or kNone if no
    336   // value is computed.  Single-valued feature functions only need to override
    337   // this method.
    338   virtual FeatureValue Compute(const WorkspaceSet &workspaces,
    339                                const OBJ &object, ARGS... args,
    340                                const FeatureVector *fv) const {
    341     return kNone;
    342   }
    343 
    344   // Instantiates a new feature function in a feature extractor from a feature
    345   // descriptor.
    346   static Self *Instantiate(GenericFeatureExtractor *extractor,
    347                            FeatureFunctionDescriptor *fd,
    348                            const std::string &prefix) {
    349     Self *f = Self::Create(fd->type());
    350     if (f != nullptr) {
    351       f->set_extractor(extractor);
    352       f->set_descriptor(fd);
    353       f->set_prefix(prefix);
    354     }
    355     return f;
    356   }
    357 
    358   // Returns the name of the registry for the feature function.
    359   const char *RegistryName() const override { return Self::registry()->name(); }
    360 
    361  private:
    362   // Special feature function class for resolving variable references. The type
    363   // of the feature function is used for resolving the variable reference. When
    364   // evaluated it will either get the feature value(s) from the variable portion
    365   // of the feature vector, if present, or otherwise it will call the referenced
    366   // feature extractor function directly to extract the feature(s).
    367   class Reference;
    368 };
    369 
    370 // Base class for features with nested feature functions. The nested functions
    371 // are of type NES, which may be different from the type of the parent function.
    372 // NB: NestedFeatureFunction will ensure that all initialization of nested
    373 // functions takes place during Setup() and Init() -- after the nested features
    374 // are initialized, the parent feature is initialized via SetupNested() and
// InitNested(). Alternatively, a derived class that overrides Setup() and
    376 // Init() directly should call Parent::Setup(), Parent::Init(), etc. first.
    377 //
    378 // Note: NestedFeatureFunction cannot know how to call Preprocess, Evaluate, or
    379 // Compute, since the nested functions may be of a different type.
    380 template <class NES, class OBJ, class... ARGS>
    381 class NestedFeatureFunction : public FeatureFunction<OBJ, ARGS...> {
    382  public:
    383   using Parent = NestedFeatureFunction<NES, OBJ, ARGS...>;
    384 
    385   // Clean up nested functions.
    386   ~NestedFeatureFunction() override {
    387     // Fully qualified class name, to avoid an ambiguity error when building for
    388     // Android.
    389     ::libtextclassifier::STLDeleteElements(&nested_);
    390   }
    391 
    392   // By default, just appends the nested feature types.
    393   void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
    394     // It's odd if a NestedFeatureFunction does not have anything nested inside
    395     // it, so we crash in debug mode.  Still, nothing should crash in prod mode.
    396     TC_DCHECK(!this->nested().empty())
    397         << "Nested features require nested features to be defined.";
    398     for (auto *function : nested_) function->GetFeatureTypes(types);
    399   }
    400 
    401   // Sets up the nested features.
    402   bool Setup(TaskContext *context) override {
    403     bool success = CreateNested(this->extractor(), this->descriptor(), &nested_,
    404                                 this->SubPrefix());
    405     if (!success) {
    406       return false;
    407     }
    408     for (auto *function : nested_) {
    409       if (!function->Setup(context)) return false;
    410     }
    411     if (!SetupNested(context)) {
    412       return false;
    413     }
    414     return true;
    415   }
    416 
    417   // Sets up this NestedFeatureFunction specifically.
    418   virtual bool SetupNested(TaskContext *context) { return true; }
    419 
    420   // Initializes the nested features.
    421   bool Init(TaskContext *context) override {
    422     for (auto *function : nested_) {
    423       if (!function->Init(context)) return false;
    424     }
    425     if (!InitNested(context)) return false;
    426     return true;
    427   }
    428 
    429   // Initializes this NestedFeatureFunction specifically.
    430   virtual bool InitNested(TaskContext *context) { return true; }
    431 
    432   // Gets all the workspaces needed for the nested functions.
    433   void RequestWorkspaces(WorkspaceRegistry *registry) override {
    434     for (auto *function : nested_) function->RequestWorkspaces(registry);
    435   }
    436 
    437   // Returns the list of nested feature functions.
    438   const std::vector<NES *> &nested() const { return nested_; }
    439 
    440   // Instantiates nested feature functions for a feature function. Creates and
    441   // initializes one feature function for each sub-descriptor in the feature
    442   // descriptor.
    443   static bool CreateNested(GenericFeatureExtractor *extractor,
    444                            FeatureFunctionDescriptor *fd,
    445                            std::vector<NES *> *functions,
    446                            const std::string &prefix) {
    447     for (int i = 0; i < fd->feature_size(); ++i) {
    448       FeatureFunctionDescriptor *sub = fd->mutable_feature(i);
    449       NES *f = NES::Instantiate(extractor, sub, prefix);
    450       if (f == nullptr) {
    451         return false;
    452       }
    453       functions->push_back(f);
    454     }
    455     return true;
    456   }
    457 
    458  protected:
    459   // The nested feature functions, if any, in order of declaration in the
    460   // feature descriptor.  Owned.
    461   std::vector<NES *> nested_;
    462 };
    463 
    464 // Base class for a nested feature function that takes nested features with the
    465 // same signature as these features, i.e. a meta feature. For this class, we can
    466 // provide preprocessing of the nested features.
    467 template <class OBJ, class... ARGS>
    468 class MetaFeatureFunction
    469     : public NestedFeatureFunction<FeatureFunction<OBJ, ARGS...>, OBJ,
    470                                    ARGS...> {
    471  public:
    472   // Preprocesses using the nested features.
    473   void Preprocess(WorkspaceSet *workspaces, OBJ *object) const override {
    474     for (auto *function : this->nested_) {
    475       function->Preprocess(workspaces, object);
    476     }
    477   }
    478 };
    479 
    480 // Template for a special type of locator: The locator of type
    481 // FeatureFunction<OBJ, ARGS...> calls nested functions of type
    482 // FeatureFunction<OBJ, IDX, ARGS...>, where the derived class DER is
    483 // responsible for translating by providing the following:
    484 //
    485 // // Gets the new additional focus.
    486 // IDX GetFocus(const WorkspaceSet &workspaces, const OBJ &object);
    487 //
    488 // This is useful to e.g. add a token focus to a parser state based on some
    489 // desired property of that state.
    490 template <class DER, class OBJ, class IDX, class... ARGS>
    491 class FeatureAddFocusLocator
    492     : public NestedFeatureFunction<FeatureFunction<OBJ, IDX, ARGS...>, OBJ,
    493                                    ARGS...> {
    494  public:
    495   void Preprocess(WorkspaceSet *workspaces, OBJ *object) const override {
    496     for (auto *function : this->nested_) {
    497       function->Preprocess(workspaces, object);
    498     }
    499   }
    500 
    501   void Evaluate(const WorkspaceSet &workspaces, const OBJ &object, ARGS... args,
    502                 FeatureVector *result) const override {
    503     IDX focus =
    504         static_cast<const DER *>(this)->GetFocus(workspaces, object, args...);
    505     for (auto *function : this->nested()) {
    506       function->Evaluate(workspaces, object, focus, args..., result);
    507     }
    508   }
    509 
    510   // Returns the first nested feature's computed value.
    511   FeatureValue Compute(const WorkspaceSet &workspaces, const OBJ &object,
    512                        ARGS... args,
    513                        const FeatureVector *result) const override {
    514     IDX focus =
    515         static_cast<const DER *>(this)->GetFocus(workspaces, object, args...);
    516     return this->nested()[0]->Compute(workspaces, object, focus, args...,
    517                                       result);
    518   }
    519 };
    520 
    521 // CRTP feature locator class. This is a meta feature that modifies ARGS and
    522 // then calls the nested feature functions with the modified ARGS. Note that in
    523 // order for this template to work correctly, all of ARGS must be types for
    524 // which the reference operator & can be interpreted as a pointer to the
    525 // argument. The derived class DER must implement the UpdateFocus method which
    526 // takes pointers to the ARGS arguments:
    527 //
    528 // // Updates the current arguments.
    529 // void UpdateArgs(const OBJ &object, ARGS *...args) const;
    530 template <class DER, class OBJ, class... ARGS>
    531 class FeatureLocator : public MetaFeatureFunction<OBJ, ARGS...> {
    532  public:
    533   // Feature locators have an additional check that there is no intrinsic type,
    534   // but only in debug mode: having an intrinsic type here is odd, but not
    535   // enough to motive a crash in prod.
    536   void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
    537     TC_DCHECK_EQ(this->feature_type(), nullptr)
    538         << "FeatureLocators should not have an intrinsic type.";
    539     MetaFeatureFunction<OBJ, ARGS...>::GetFeatureTypes(types);
    540   }
    541 
    542   // Evaluates the locator.
    543   void Evaluate(const WorkspaceSet &workspaces, const OBJ &object, ARGS... args,
    544                 FeatureVector *result) const override {
    545     static_cast<const DER *>(this)->UpdateArgs(workspaces, object, &args...);
    546     for (auto *function : this->nested()) {
    547       function->Evaluate(workspaces, object, args..., result);
    548     }
    549   }
    550 
    551   // Returns the first nested feature's computed value.
    552   FeatureValue Compute(const WorkspaceSet &workspaces, const OBJ &object,
    553                        ARGS... args,
    554                        const FeatureVector *result) const override {
    555     static_cast<const DER *>(this)->UpdateArgs(workspaces, object, &args...);
    556     return this->nested()[0]->Compute(workspaces, object, args..., result);
    557   }
    558 };
    559 
    560 // Feature extractor for extracting features from objects of a certain class.
    561 // Template type parameters are as defined for FeatureFunction.
    562 template <class OBJ, class... ARGS>
    563 class FeatureExtractor : public GenericFeatureExtractor {
    564  public:
    565   // Feature function type for top-level functions in the feature extractor.
    566   typedef FeatureFunction<OBJ, ARGS...> Function;
    567   typedef FeatureExtractor<OBJ, ARGS...> Self;
    568 
    569   // Feature locator type for the feature extractor.
    570   template <class DER>
    571   using Locator = FeatureLocator<DER, OBJ, ARGS...>;
    572 
    573   // Initializes feature extractor.
    574   FeatureExtractor() {}
    575 
    576   ~FeatureExtractor() override {
    577     // Fully qualified class name, to avoid an ambiguity error when building for
    578     // Android.
    579     ::libtextclassifier::STLDeleteElements(&functions_);
    580   }
    581 
    582   // Sets up the feature extractor. Note that only top-level functions exist
    583   // until Setup() is called. This does not take ownership over the context,
    584   // which must outlive this.
    585   bool Setup(TaskContext *context) {
    586     for (Function *function : functions_) {
    587       if (!function->Setup(context)) return false;
    588     }
    589     return true;
    590   }
    591 
    592   // Initializes the feature extractor.  Must be called after Setup().  This
    593   // does not take ownership over the context, which must outlive this.
    594   bool Init(TaskContext *context) {
    595     for (Function *function : functions_) {
    596       if (!function->Init(context)) return false;
    597     }
    598     if (!this->InitializeFeatureTypes()) {
    599       return false;
    600     }
    601     return true;
    602   }
    603 
    604   // Requests workspaces from the registry. Must be called after Init(), and
    605   // before Preprocess(). Does not take ownership over registry. This should be
    606   // the same registry used to initialize the WorkspaceSet used in Preprocess()
    607   // and ExtractFeatures(). NB: This is a different ordering from that used in
    608   // SentenceFeatureRepresentation style feature computation.
    609   void RequestWorkspaces(WorkspaceRegistry *registry) {
    610     for (auto *function : functions_) function->RequestWorkspaces(registry);
    611   }
    612 
    613   // Preprocesses the object using feature functions for the phase.  Must be
    614   // called before any calls to ExtractFeatures() on that object and phase.
    615   void Preprocess(WorkspaceSet *workspaces, OBJ *object) const {
    616     for (Function *function : functions_) {
    617       function->Preprocess(workspaces, object);
    618     }
    619   }
    620 
    621   // Extracts features from an object with a focus. This invokes all the
    622   // top-level feature functions in the feature extractor. Only feature
    623   // functions belonging to the specified phase are invoked.
    624   void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &object,
    625                        ARGS... args, FeatureVector *result) const {
    626     result->reserve(this->feature_types());
    627 
    628     // Extract features.
    629     for (int i = 0; i < functions_.size(); ++i) {
    630       functions_[i]->Evaluate(workspaces, object, args..., result);
    631     }
    632   }
    633 
    634  private:
    635   // Creates and initializes all feature functions in the feature extractor.
    636   bool InitializeFeatureFunctions() override {
    637     // Create all top-level feature functions.
    638     for (int i = 0; i < descriptor().feature_size(); ++i) {
    639       FeatureFunctionDescriptor *fd = mutable_descriptor()->mutable_feature(i);
    640       Function *function = Function::Instantiate(this, fd, "");
    641       if (function == nullptr) return false;
    642       functions_.push_back(function);
    643     }
    644     return true;
    645   }
    646 
    647   // Collect all feature types used in the feature extractor.
    648   void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
    649     for (Function *function : functions_) {
    650       function->GetFeatureTypes(types);
    651     }
    652   }
    653 
    654   // Top-level feature functions (and variables) in the feature extractor.
    655   // Owned.  INVARIANT: contains only non-null pointers.
    656   std::vector<Function *> functions_;
    657 };
    658 
    659 #define REGISTER_FEATURE_FUNCTION(base, name, component) \
    660   REGISTER_CLASS_COMPONENT(base, name, component)
    661 
    662 }  // namespace nlp_core
    663 }  // namespace libtextclassifier
    664 
    665 #endif  // LIBTEXTCLASSIFIER_COMMON_FEATURE_EXTRACTOR_H_
    666