Home | History | Annotate | Download | only in smartselect
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 // Feature processing for FFModel (feed-forward SmartSelection model).
     18 
     19 #ifndef LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
     20 #define LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
     21 
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
     25 
     26 #include "smartselect/cached-features.h"
     27 #include "smartselect/text-classification-model.pb.h"
     28 #include "smartselect/token-feature-extractor.h"
     29 #include "smartselect/tokenizer.h"
     30 #include "smartselect/types.h"
     31 #include "util/base/logging.h"
     32 #include "util/utf8/unicodetext.h"
     33 
     34 namespace libtextclassifier {
     35 
     36 constexpr int kInvalidLabel = -1;
     37 
     38 // Maps a vector of sparse features and a vector of dense features to a vector
     39 // of features that combines both.
     40 // The output is written to the memory location pointed to  by the last float*
     41 // argument.
     42 // Returns true on success false on failure.
     43 using FeatureVectorFn = std::function<bool(const std::vector<int>&,
     44                                            const std::vector<float>&, float*)>;
     45 
     46 namespace internal {
     47 
     48 // Parses the serialized protocol buffer.
     49 FeatureProcessorOptions ParseSerializedOptions(
     50     const std::string& serialized_options);
     51 
     52 TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
     53     const FeatureProcessorOptions& options);
     54 
     55 // Removes tokens that are not part of a line of the context which contains
     56 // given span.
     57 void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
     58                                std::vector<Token>* tokens);
     59 
     60 // Splits tokens that contain the selection boundary inside them.
     61 // E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
     62 void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
     63                                       std::vector<Token>* tokens);
     64 
     65 // Returns the index of token that corresponds to the codepoint span.
     66 int CenterTokenFromClick(CodepointSpan span, const std::vector<Token>& tokens);
     67 
     68 // Returns the index of token that corresponds to the middle of the  codepoint
     69 // span.
     70 int CenterTokenFromMiddleOfSelection(
     71     CodepointSpan span, const std::vector<Token>& selectable_tokens);
     72 
     73 // Strips the tokens from the tokens vector that are not used for feature
     74 // extraction because they are out of scope, or pads them so that there is
     75 // enough tokens in the required context_size for all inferences with a click
     76 // in relative_click_span.
     77 void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
     78                       std::vector<Token>* tokens, int* click_pos);
     79 
     80 }  // namespace internal
     81 
     82 // Converts a codepoint span to a token span in the given list of tokens.
     83 TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
     84                                    CodepointSpan codepoint_span);
     85 
     86 // Converts a token span to a codepoint span in the given list of tokens.
     87 CodepointSpan TokenSpanToCodepointSpan(
     88     const std::vector<Token>& selectable_tokens, TokenSpan token_span);
     89 
     90 // Takes care of preparing features for the span prediction model.
     91 class FeatureProcessor {
     92  public:
     93   explicit FeatureProcessor(const FeatureProcessorOptions& options)
     94       : feature_extractor_(
     95             internal::BuildTokenFeatureExtractorOptions(options)),
     96         options_(options),
     97         tokenizer_({options.tokenization_codepoint_config().begin(),
     98                     options.tokenization_codepoint_config().end()}) {
     99     MakeLabelMaps();
    100     PrepareCodepointRanges({options.supported_codepoint_ranges().begin(),
    101                             options.supported_codepoint_ranges().end()},
    102                            &supported_codepoint_ranges_);
    103     PrepareCodepointRanges(
    104         {options.internal_tokenizer_codepoint_ranges().begin(),
    105          options.internal_tokenizer_codepoint_ranges().end()},
    106         &internal_tokenizer_codepoint_ranges_);
    107   }
    108 
    109   explicit FeatureProcessor(const std::string& serialized_options)
    110       : FeatureProcessor(internal::ParseSerializedOptions(serialized_options)) {
    111   }
    112 
    113   // Tokenizes the input string using the selected tokenization method.
    114   std::vector<Token> Tokenize(const std::string& utf8_text) const;
    115 
    116   // Converts a label into a token span.
    117   bool LabelToTokenSpan(int label, TokenSpan* token_span) const;
    118 
    119   // Gets the total number of selection labels.
    120   int GetSelectionLabelCount() const { return label_to_selection_.size(); }
    121 
    122   // Gets the string value for given collection label.
    123   std::string LabelToCollection(int label) const;
    124 
    125   // Gets the total number of collections of the model.
    126   int NumCollections() const { return collection_to_label_.size(); }
    127 
    128   // Gets the name of the default collection.
    129   std::string GetDefaultCollection() const;
    130 
    131   const FeatureProcessorOptions& GetOptions() const { return options_; }
    132 
    133   // Tokenizes the context and input span, and finds the click position.
    134   void TokenizeAndFindClick(const std::string& context,
    135                             CodepointSpan input_span,
    136                             std::vector<Token>* tokens, int* click_pos) const;
    137 
    138   // Extracts features as a CachedFeatures object that can be used for repeated
    139   // inference over token spans in the given context.
    140   bool ExtractFeatures(const std::string& context, CodepointSpan input_span,
    141                        TokenSpan relative_click_span,
    142                        const FeatureVectorFn& feature_vector_fn,
    143                        int feature_vector_size, std::vector<Token>* tokens,
    144                        int* click_pos,
    145                        std::unique_ptr<CachedFeatures>* cached_features) const;
    146 
    147   // Fills selection_label_spans with CodepointSpans that correspond to the
    148   // selection labels. The CodepointSpans are based on the codepoint ranges of
    149   // given tokens.
    150   bool SelectionLabelSpans(
    151       VectorSpan<Token> tokens,
    152       std::vector<CodepointSpan>* selection_label_spans) const;
    153 
    154   int DenseFeaturesCount() const {
    155     return feature_extractor_.DenseFeaturesCount();
    156   }
    157 
    158  protected:
    159   // Represents a codepoint range [start, end).
    160   struct CodepointRange {
    161     int32 start;
    162     int32 end;
    163 
    164     CodepointRange(int32 arg_start, int32 arg_end)
    165         : start(arg_start), end(arg_end) {}
    166   };
    167 
    168   // Returns the class id corresponding to the given string collection
    169   // identifier. There is a catch-all class id that the function returns for
    170   // unknown collections.
    171   int CollectionToLabel(const std::string& collection) const;
    172 
    173   // Prepares mapping from collection names to labels.
    174   void MakeLabelMaps();
    175 
    176   // Gets the number of spannable tokens for the model.
    177   //
    178   // Spannable tokens are those tokens of context, which the model predicts
    179   // selection spans over (i.e., there is 1:1 correspondence between the output
    180   // classes of the model and each of the spannable tokens).
    181   int GetNumContextTokens() const { return options_.context_size() * 2 + 1; }
    182 
    183   // Converts a label into a span of codepoint indices corresponding to it
    184   // given output_tokens.
    185   bool LabelToSpan(int label, const VectorSpan<Token>& output_tokens,
    186                    CodepointSpan* span) const;
    187 
    188   // Converts a span to the corresponding label given output_tokens.
    189   bool SpanToLabel(const std::pair<CodepointIndex, CodepointIndex>& span,
    190                    const std::vector<Token>& output_tokens, int* label) const;
    191 
    192   // Converts a token span to the corresponding label.
    193   int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;
    194 
    195   void PrepareCodepointRanges(
    196       const std::vector<FeatureProcessorOptions::CodepointRange>&
    197           codepoint_ranges,
    198       std::vector<CodepointRange>* prepared_codepoint_ranges);
    199 
    200   // Returns the ratio of supported codepoints to total number of codepoints in
    201   // the input context around given click position.
    202   float SupportedCodepointsRatio(int click_pos,
    203                                  const std::vector<Token>& tokens) const;
    204 
    205   // Returns true if given codepoint is covered by the given sorted vector of
    206   // codepoint ranges.
    207   bool IsCodepointInRanges(
    208       int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const;
    209 
    210   // Finds the center token index in tokens vector, using the method defined
    211   // in options_.
    212   int FindCenterToken(CodepointSpan span,
    213                       const std::vector<Token>& tokens) const;
    214 
    215   // Tokenizes the input text using ICU tokenizer.
    216   bool ICUTokenize(const std::string& context,
    217                    std::vector<Token>* result) const;
    218 
    219   // Takes the result of ICU tokenization and retokenizes stretches of tokens
    220   // made of a specific subset of characters using the internal tokenizer.
    221   void InternalRetokenize(const std::string& context,
    222                           std::vector<Token>* tokens) const;
    223 
    224   // Tokenizes a substring of the unicode string, appending the resulting tokens
    225   // to the output vector. The resulting tokens have bounds relative to the full
    226   // string. Does nothing if the start of the span is negative.
    227   void TokenizeSubstring(const UnicodeText& unicode_text, CodepointSpan span,
    228                          std::vector<Token>* result) const;
    229 
    230   const TokenFeatureExtractor feature_extractor_;
    231 
    232   // Codepoint ranges that define what codepoints are supported by the model.
    233   // NOTE: Must be sorted.
    234   std::vector<CodepointRange> supported_codepoint_ranges_;
    235 
    236   // Codepoint ranges that define which tokens (consisting of which codepoints)
    237   // should be re-tokenized with the internal tokenizer in the mixed
    238   // tokenization mode.
    239   // NOTE: Must be sorted.
    240   std::vector<CodepointRange> internal_tokenizer_codepoint_ranges_;
    241 
    242  private:
    243   const FeatureProcessorOptions options_;
    244 
    245   // Mapping between token selection spans and labels ids.
    246   std::map<TokenSpan, int> selection_to_label_;
    247   std::vector<TokenSpan> label_to_selection_;
    248 
    249   // Mapping between collections and labels.
    250   std::map<std::string, int> collection_to_label_;
    251 
    252   Tokenizer tokenizer_;
    253 };
    254 
    255 }  // namespace libtextclassifier
    256 
    257 #endif  // LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
    258