Home | History | Annotate | Download | only in libtextclassifier
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef LIBTEXTCLASSIFIER_CACHED_FEATURES_H_
     18 #define LIBTEXTCLASSIFIER_CACHED_FEATURES_H_
     19 
     20 #include <memory>
     21 #include <vector>
     22 
     23 #include "model-executor.h"
     24 #include "model_generated.h"
     25 #include "types.h"
     26 
     27 namespace libtextclassifier2 {
     28 
     29 // Holds state for extracting features across multiple calls and reusing them.
     30 // Assumes that features for each Token are independent.
     31 class CachedFeatures {
     32  public:
     33   static std::unique_ptr<CachedFeatures> Create(
     34       const TokenSpan& extraction_span,
     35       std::unique_ptr<std::vector<float>> features,
     36       std::unique_ptr<std::vector<float>> padding_features,
     37       const FeatureProcessorOptions* options, int feature_vector_size);
     38 
     39   // Appends the click context features for the given click position to
     40   // 'output_features'.
     41   void AppendClickContextFeaturesForClick(
     42       int click_pos, std::vector<float>* output_features) const;
     43 
     44   // Appends the bounds-sensitive features for the given token span to
     45   // 'output_features'.
     46   void AppendBoundsSensitiveFeaturesForSpan(
     47       TokenSpan selected_span, std::vector<float>* output_features) const;
     48 
     49   // Returns number of features that 'AppendFeaturesForSpan' appends.
     50   int OutputFeaturesSize() const { return output_features_size_; }
     51 
     52  private:
     53   CachedFeatures() {}
     54 
     55   // Appends token features to the output. The intended_span specifies which
     56   // tokens' features should be used in principle. The read_mask_span restricts
     57   // which tokens are actually read. For tokens outside of the read_mask_span,
     58   // padding tokens are used instead.
     59   void AppendFeaturesInternal(const TokenSpan& intended_span,
     60                               const TokenSpan& read_mask_span,
     61                               std::vector<float>* output_features) const;
     62 
     63   // Appends features of one padding token to the output.
     64   void AppendPaddingFeatures(std::vector<float>* output_features) const;
     65 
     66   // Appends the features of tokens from the given span to the output. The
     67   // features are averaged so that the appended features have the size
     68   // corresponding to one token.
     69   void AppendBagFeatures(const TokenSpan& bag_span,
     70                          std::vector<float>* output_features) const;
     71 
     72   int NumFeaturesPerToken() const;
     73 
     74   TokenSpan extraction_span_;
     75   const FeatureProcessorOptions* options_;
     76   int output_features_size_;
     77   std::unique_ptr<std::vector<float>> features_;
     78   std::unique_ptr<std::vector<float>> padding_features_;
     79 };
     80 
     81 }  // namespace libtextclassifier2
     82 
     83 #endif  // LIBTEXTCLASSIFIER_CACHED_FEATURES_H_
     84