Home | History | Annotate | Download | only in annotator
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 // Feature processing for FFModel (feed-forward SmartSelection model).
     18 
     19 #ifndef LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
     20 #define LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
     21 
     22 #include <map>
     23 #include <memory>
     24 #include <set>
     25 #include <string>
     26 #include <vector>
     27 
     28 #include "annotator/cached-features.h"
     29 #include "annotator/model_generated.h"
     30 #include "annotator/types.h"
     31 #include "utils/base/integral_types.h"
     32 #include "utils/base/logging.h"
     33 #include "utils/token-feature-extractor.h"
     34 #include "utils/tokenizer.h"
     35 #include "utils/utf8/unicodetext.h"
     36 #include "utils/utf8/unilib.h"
     37 
     38 namespace libtextclassifier3 {
     39 
     40 constexpr int kInvalidLabel = -1;  // NOTE(review): sentinel for "no valid label", judging by the name; no use is visible in this header.
     41 
     42 namespace internal {
     43 
        // Builds a Tokenizer from the given options and UniLib. The FeatureProcessor
        // constructor declared later in this header uses it to initialize its
        // tokenizer_ member.
     44 Tokenizer BuildTokenizer(const FeatureProcessorOptions* options,
     45                          const UniLib* unilib);
     46 
        // Builds TokenFeatureExtractorOptions from the given FeatureProcessorOptions.
        // The FeatureProcessor constructor declared later in this header passes the
        // result to its feature_extractor_ member.
     47 TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
     48     const FeatureProcessorOptions* options);
     49 
     50 // Splits tokens that contain the selection boundary inside them.
     51 // E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
        // The function returns void; the result is delivered through the `tokens`
        // in/out argument.
     52 void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
     53                                       std::vector<Token>* tokens);
     54 
     55 // Returns the index of the token that corresponds to the codepoint span.
        // NOTE(review): the value returned when no token matches is not visible in
        // this header -- confirm in the implementation.
     56 int CenterTokenFromClick(CodepointSpan span, const std::vector<Token>& tokens);
     57 
     58 // Returns the index of the token that corresponds to the middle of the
     59 // codepoint span.
     60 int CenterTokenFromMiddleOfSelection(
     61     CodepointSpan span, const std::vector<Token>& selectable_tokens);
     62 
     63 // Strips the tokens from the tokens vector that are not used for feature
     64 // extraction because they are out of scope, or pads them so that there are
     65 // enough tokens in the required context_size for all inferences with a click
     66 // in relative_click_span.
        // `tokens` and `click_pos` are passed as mutable pointers.
        // NOTE(review): click_pos is presumably adjusted to stay consistent with the
        // stripped/padded vector -- confirm in the implementation.
     67 void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
     68                       std::vector<Token>* tokens, int* click_pos);
     69 
     70 }  // namespace internal
     71 
     72 // Converts a codepoint span to a token span in the given list of tokens.
     73 // If snap_boundaries_to_containing_tokens is set to true, it is enough for a
     74 // token to overlap with the codepoint range to be considered part of it.
     75 // Otherwise (the default, false) it must be fully included in the range.
     76 TokenSpan CodepointSpanToTokenSpan(
     77     const std::vector<Token>& selectable_tokens, CodepointSpan codepoint_span,
     78     bool snap_boundaries_to_containing_tokens = false);
     79 
     80 // Converts a token span to a codepoint span in the given list of tokens.
        // This is the opposite direction of CodepointSpanToTokenSpan() above.
        // NOTE(review): behavior for an empty or out-of-range token_span is not
        // visible in this header -- confirm in the implementation.
     81 CodepointSpan TokenSpanToCodepointSpan(
     82     const std::vector<Token>& selectable_tokens, TokenSpan token_span);
     83 
     84 // Takes care of preparing features for the span prediction model.
        //
        // Holds a tokenizer, a token feature extractor, the selection/collection
        // label mappings and the set of supported codepoint ranges, all configured
        // from a FeatureProcessorOptions object supplied at construction.
     85 class FeatureProcessor {
     86  public:
     87   // A cache mapping codepoint spans to embedded tokens features. An instance
     88   // can be provided to multiple calls to ExtractFeatures() operating on the
     89   // same context (the same codepoint spans corresponding to the same tokens),
     90   // as an optimization. Note that the tokenizations do not have to be
     91   // identical.
     92   typedef std::map<CodepointSpan, std::vector<float>> EmbeddingCache;
     93 
          // Constructs the processor from `options` and `unilib`.
          //
          // Both pointers are dereferenced here without a null check, so they must
          // be non-null. `options` is stored as a raw pointer and dereferenced by
          // later calls (e.g. EmbeddingSize()), so it must outlive this object.
          //
          // Besides initializing the feature extractor and the tokenizer, the body
          // builds the label maps, sorts the supported codepoint ranges (only when
          // the options provide them) and prepares the set of ignored span boundary
          // codepoints.
     94   FeatureProcessor(const FeatureProcessorOptions* options, const UniLib* unilib)
     95       : feature_extractor_(internal::BuildTokenFeatureExtractorOptions(options),
     96                            *unilib),
     97         options_(options),
     98         tokenizer_(internal::BuildTokenizer(options, unilib)) {
     99     MakeLabelMaps();
    100     if (options->supported_codepoint_ranges() != nullptr) {
    101       SortCodepointRanges({options->supported_codepoint_ranges()->begin(),
    102                            options->supported_codepoint_ranges()->end()},
    103                           &supported_codepoint_ranges_);
    104     }
    105     PrepareIgnoredSpanBoundaryCodepoints();
    106   }
    107 
    108   // Tokenizes the input string using the selected tokenization method.
    109   std::vector<Token> Tokenize(const std::string& text) const;
    110 
    111   // Same as above but takes UnicodeText.
    112   std::vector<Token> Tokenize(const UnicodeText& text_unicode) const;
    113 
    114   // Converts a label into a token span.
          // NOTE(review): the bool result presumably signals whether the label was
          // valid -- confirm in the implementation.
    115   bool LabelToTokenSpan(int label, TokenSpan* token_span) const;
    116 
    117   // Gets the total number of selection labels (the size of
          // label_to_selection_).
    118   int GetSelectionLabelCount() const { return label_to_selection_.size(); }
    119 
    120   // Gets the string value for given collection label.
    121   std::string LabelToCollection(int label) const;
    122 
    123   // Gets the total number of collections of the model (the size of
          // collection_to_label_).
    124   int NumCollections() const { return collection_to_label_.size(); }
    125 
    126   // Gets the name of the default collection.
    127   std::string GetDefaultCollection() const;
    128 
          // Returns the options pointer this processor was constructed with.
    129   const FeatureProcessorOptions* GetOptions() const { return options_; }
    130 
    131   // Retokenizes the context and input span, and finds the click position.
    132   // Depending on the options, might modify tokens (split them or remove them).
          // Results are delivered through the `tokens` and `click_pos` pointers.
    133   void RetokenizeAndFindClick(const std::string& context,
    134                               CodepointSpan input_span,
    135                               bool only_use_line_with_click,
    136                               std::vector<Token>* tokens, int* click_pos) const;
    137 
    138   // Same as above but takes UnicodeText.
    139   void RetokenizeAndFindClick(const UnicodeText& context_unicode,
    140                               CodepointSpan input_span,
    141                               bool only_use_line_with_click,
    142                               std::vector<Token>* tokens, int* click_pos) const;
    143 
    144   // Returns true if the token span has enough supported codepoints (as defined
    145   // in the model config); returns false otherwise, in which case the model
          // should not run.
    146   bool HasEnoughSupportedCodepoints(const std::vector<Token>& tokens,
    147                                     TokenSpan token_span) const;
    148 
    149   // Extracts features as a CachedFeatures object that can be used for repeated
    150   // inference over token spans in the given context.
          // The result is written to `cached_features`.
          // NOTE(review): the bool result presumably signals success -- confirm in
          // the implementation.
    151   bool ExtractFeatures(const std::vector<Token>& tokens, TokenSpan token_span,
    152                        CodepointSpan selection_span_for_feature,
    153                        const EmbeddingExecutor* embedding_executor,
    154                        EmbeddingCache* embedding_cache, int feature_vector_size,
    155                        std::unique_ptr<CachedFeatures>* cached_features) const;
    156 
    157   // Fills selection_label_spans with CodepointSpans that correspond to the
    158   // selection labels. The CodepointSpans are based on the codepoint ranges of
    159   // given tokens.
    160   bool SelectionLabelSpans(
    161       VectorSpan<Token> tokens,
    162       std::vector<CodepointSpan>* selection_label_spans) const;
    163 
          // Returns the dense feature count reported by the token feature extractor.
    164   int DenseFeaturesCount() const {
    165     return feature_extractor_.DenseFeaturesCount();
    166   }
    167 
          // Returns the embedding size configured in the options.
    168   int EmbeddingSize() const { return options_->embedding_size(); }
    169 
    170   // Splits context to several segments.
    171   std::vector<UnicodeTextRange> SplitContext(
    172       const UnicodeText& context_unicode) const;
    173 
    174   // Strips boundary codepoints from the span in context and returns the new
    175   // start and end indices. If the span consists entirely of boundary
    176   // codepoints, the first index of span is returned for both indices.
    177   CodepointSpan StripBoundaryCodepoints(const std::string& context,
    178                                         CodepointSpan span) const;
    179 
    180   // Same as above but takes UnicodeText.
    181   CodepointSpan StripBoundaryCodepoints(const UnicodeText& context_unicode,
    182                                         CodepointSpan span) const;
    183 
    184   // Same as above but takes a pair of iterators for the span, for efficiency.
    185   CodepointSpan StripBoundaryCodepoints(
    186       const UnicodeText::const_iterator& span_begin,
    187       const UnicodeText::const_iterator& span_end, CodepointSpan span) const;
    188 
    189   // Same as above, but takes an optional buffer for saving the modified value.
    190   // As an optimization, returns a reference to 'value' if nothing was
    191   // stripped, or a reference to 'buffer' if something was stripped.
    192   const std::string& StripBoundaryCodepoints(const std::string& value,
    193                                              std::string* buffer) const;
    194 
    195  protected:
    196   // Returns the class id corresponding to the given string collection
    197   // identifier. There is a catch-all class id that the function returns for
    198   // unknown collections.
    199   int CollectionToLabel(const std::string& collection) const;
    200 
    201   // Prepares mapping from collection names to labels.
          // Called from the constructor.
    202   void MakeLabelMaps();
    203 
    204   // Gets the number of spannable tokens for the model.
    205   //
    206   // Spannable tokens are those tokens of context, which the model predicts
    207   // selection spans over (i.e., there is 1:1 correspondence between the output
    208   // classes of the model and each of the spannable tokens).
          //
          // Computed as 2 * context_size + 1; NOTE(review): presumably context_size
          // tokens on each side of the center token -- confirm.
    209   int GetNumContextTokens() const { return options_->context_size() * 2 + 1; }
    210 
    211   // Converts a label into a span of codepoint indices corresponding to it
    212   // given output_tokens.
    213   bool LabelToSpan(int label, const VectorSpan<Token>& output_tokens,
    214                    CodepointSpan* span) const;
    215 
    216   // Converts a span to the corresponding label given output_tokens.
    217   bool SpanToLabel(const std::pair<CodepointIndex, CodepointIndex>& span,
    218                    const std::vector<Token>& output_tokens, int* label) const;
    219 
    220   // Converts a token span to the corresponding label.
    221   int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;
    222 
    223   // Returns the ratio of supported codepoints to total number of codepoints in
    224   // the given token span.
    225   float SupportedCodepointsRatio(const TokenSpan& token_span,
    226                                  const std::vector<Token>& tokens) const;
    227 
          // Called from the constructor. NOTE(review): presumably fills
          // ignored_span_boundary_codepoints_ from the options -- confirm in the
          // implementation.
    228   void PrepareIgnoredSpanBoundaryCodepoints();
    229 
    230   // Counts the number of span boundary codepoints. If count_from_beginning is
    231   // True, the counting will start at the span_start iterator (inclusive) and at
    232   // maximum end at span_end (exclusive). If count_from_beginning is False, the
    233   // counting will start from span_end (exclusive) and end at span_start
    234   // (inclusive).
    235   int CountIgnoredSpanBoundaryCodepoints(
    236       const UnicodeText::const_iterator& span_start,
    237       const UnicodeText::const_iterator& span_end,
    238       bool count_from_beginning) const;
    239 
    240   // Finds the center token index in tokens vector, using the method defined
    241   // in options_.
    242   int FindCenterToken(CodepointSpan span,
    243                       const std::vector<Token>& tokens) const;
    244 
    245   // Removes all tokens from tokens that are not on a line (defined by calling
    246   // SplitContext on the context) to which span points.
    247   void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
    248                                  std::vector<Token>* tokens) const;
    249 
    250   // Same as above but takes UnicodeText.
    251   void StripTokensFromOtherLines(const UnicodeText& context_unicode,
    252                                  CodepointSpan span,
    253                                  std::vector<Token>* tokens) const;
    254 
    255   // Extracts the features of a token and appends them to the output vector.
    256   // Uses the embedding cache to avoid re-extracting and re-embedding the
    257   // sparse features for the same token.
    258   bool AppendTokenFeaturesWithCache(const Token& token,
    259                                     CodepointSpan selection_span_for_feature,
    260                                     const EmbeddingExecutor* embedding_executor,
    261                                     EmbeddingCache* embedding_cache,
    262                                     std::vector<float>* output_features) const;
    263 
          // Data members accessible to subclasses.
    264  protected:
          // Token feature extractor, constructed from the options and UniLib in the
          // constructor's initializer list.
    265   const TokenFeatureExtractor feature_extractor_;
    266 
    267   // Codepoint ranges that define what codepoints are supported by the model.
    268   // NOTE: Must be sorted.
          // Filled by SortCodepointRanges() in the constructor when the options
          // provide supported_codepoint_ranges.
    269   std::vector<CodepointRangeStruct> supported_codepoint_ranges_;
    270 
    271  private:
    272   // Set of codepoints that will be stripped from beginning and end of
    273   // predicted spans.
    274   std::set<int32> ignored_span_boundary_codepoints_;
    275 
          // Options pointer supplied at construction. NOTE(review): no deletion is
          // visible in this header, so it appears not to be owned by this class.
    276   const FeatureProcessorOptions* const options_;
    277 
    278   // Mapping between token selection spans and labels ids.
    279   std::map<TokenSpan, int> selection_to_label_;
    280   std::vector<TokenSpan> label_to_selection_;
    281 
    282   // Mapping between collections and labels.
    283   std::map<std::string, int> collection_to_label_;
    284 
          // Tokenizer built by internal::BuildTokenizer() in the constructor.
    285   Tokenizer tokenizer_;
    286 };
    287 
    288 }  // namespace libtextclassifier3
    289 
    290 #endif  // LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
    291