Home | History | Annotate | Download | only in smartselect
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_
     18 #define LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_
     19 
#include <functional>
#include <memory>
#include <vector>

#include "base.h"
#include "common/vector-span.h"
#include "smartselect/types.h"
     26 
     27 namespace libtextclassifier {
     28 
     29 // Holds state for extracting features across multiple calls and reusing them.
     30 // Assumes that features for each Token are independent.
     31 class CachedFeatures {
     32  public:
     33   // Extracts the features for the given sequence of tokens.
     34   //  - context_size: Specifies how many tokens to the left, and how many
     35   //                   tokens to the right spans the context.
     36   //  - sparse_features, dense_features: Extracted features for each token.
     37   //  - feature_vector_fn: Writes features for given Token to the specified
     38   //                       storage.
     39   //                       NOTE: The function can assume that the underlying
     40   //                       storage is initialized to all zeros.
     41   //  - feature_vector_size: Size of a feature vector for one Token.
     42   CachedFeatures(VectorSpan<Token> tokens, int context_size,
     43                  const std::vector<std::vector<int>>& sparse_features,
     44                  const std::vector<std::vector<float>>& dense_features,
     45                  const std::function<bool(const std::vector<int>&,
     46                                           const std::vector<float>&, float*)>&
     47                      feature_vector_fn,
     48                  int feature_vector_size)
     49       : tokens_(tokens),
     50         context_size_(context_size),
     51         feature_vector_size_(feature_vector_size),
     52         remap_v0_feature_vector_(false),
     53         remap_v0_chargram_embedding_size_(-1) {
     54     Extract(sparse_features, dense_features, feature_vector_fn);
     55   }
     56 
     57   // Gets a VectorSpan with the features for given click position.
     58   bool Get(int click_pos, VectorSpan<float>* features,
     59            VectorSpan<Token>* output_tokens);
     60 
     61   // Turns on a compatibility mode, which re-maps the extracted features to the
     62   // v0 feature format (where the dense features were at the end).
     63   // WARNING: Internally v0_feature_storage_ is used as a backing buffer for
     64   // VectorSpan<float>, so the output of Extract is valid only until the next
     65   // call or destruction of the current CachedFeatures object.
     66   // TODO(zilka): Remove when we'll have retrained models.
     67   void SetV0FeatureMode(int chargram_embedding_size) {
     68     remap_v0_feature_vector_ = true;
     69     remap_v0_chargram_embedding_size_ = chargram_embedding_size;
     70     v0_feature_storage_.resize(feature_vector_size_ * (context_size_ * 2 + 1));
     71   }
     72 
     73  protected:
     74   // Extracts features for all tokens and stores them for later retrieval.
     75   void Extract(const std::vector<std::vector<int>>& sparse_features,
     76                const std::vector<std::vector<float>>& dense_features,
     77                const std::function<bool(const std::vector<int>&,
     78                                         const std::vector<float>&, float*)>&
     79                    feature_vector_fn);
     80 
     81   // Remaps extracted features to V0 feature format. The mapping is using
     82   // the v0_feature_storage_ as the backing storage for the mapped features.
     83   // For each token the features consist of:
     84   //  - chargram embeddings
     85   //  - dense features
     86   // They are concatenated together as [chargram embeddings; dense features]
     87   // for each token independently.
     88   // The V0 features require that the chargram embeddings for tokens are
     89   // concatenated first together, and at the end, the dense features for the
     90   // tokens are concatenated to it.
     91   void RemapV0FeatureVector(VectorSpan<float>* features);
     92 
     93  private:
     94   const VectorSpan<Token> tokens_;
     95   const int context_size_;
     96   const int feature_vector_size_;
     97   bool remap_v0_feature_vector_;
     98   int remap_v0_chargram_embedding_size_;
     99 
    100   std::vector<float> features_;
    101   std::vector<float> v0_feature_storage_;
    102 };
    103 
    104 }  // namespace libtextclassifier
    105 
    106 #endif  // LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_
    107