Home | History | Annotate | Download | only in libtextclassifier
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 // Inference code for the text classification model.
     18 
     19 #ifndef LIBTEXTCLASSIFIER_TEXT_CLASSIFIER_H_
     20 #define LIBTEXTCLASSIFIER_TEXT_CLASSIFIER_H_
     21 
#include <memory>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "datetime/parser.h"
#include "feature-processor.h"
#include "model-executor.h"
#include "model_generated.h"
#include "strip-unpaired-brackets.h"
#include "types.h"
#include "util/memory/mmap.h"
#include "util/utf8/unilib.h"
#include "zlib-utils.h"
     36 
     37 namespace libtextclassifier2 {
     38 
     39 struct SelectionOptions {
     40   // Comma-separated list of locale specification for the input text (BCP 47
     41   // tags).
     42   std::string locales;
     43 
     44   static SelectionOptions Default() { return SelectionOptions(); }
     45 };
     46 
     47 struct ClassificationOptions {
     48   // For parsing relative datetimes, the reference now time against which the
     49   // relative datetimes get resolved.
     50   // UTC milliseconds since epoch.
     51   int64 reference_time_ms_utc = 0;
     52 
     53   // Timezone in which the input text was written (format as accepted by ICU).
     54   std::string reference_timezone;
     55 
     56   // Comma-separated list of locale specification for the input text (BCP 47
     57   // tags).
     58   std::string locales;
     59 
     60   static ClassificationOptions Default() { return ClassificationOptions(); }
     61 };
     62 
     63 struct AnnotationOptions {
     64   // For parsing relative datetimes, the reference now time against which the
     65   // relative datetimes get resolved.
     66   // UTC milliseconds since epoch.
     67   int64 reference_time_ms_utc = 0;
     68 
     69   // Timezone in which the input text was written (format as accepted by ICU).
     70   std::string reference_timezone;
     71 
     72   // Comma-separated list of locale specification for the input text (BCP 47
     73   // tags).
     74   std::string locales;
     75 
     76   static AnnotationOptions Default() { return AnnotationOptions(); }
     77 };
     78 
     79 // Holds TFLite interpreters for selection and classification models.
     80 // NOTE: his class is not thread-safe, thus should NOT be re-used across
     81 // threads.
     82 class InterpreterManager {
     83  public:
     84   // The constructor can be called with nullptr for any of the executors, and is
     85   // a defined behavior, as long as the corresponding *Interpreter() method is
     86   // not called when the executor is null.
     87   InterpreterManager(const ModelExecutor* selection_executor,
     88                      const ModelExecutor* classification_executor)
     89       : selection_executor_(selection_executor),
     90         classification_executor_(classification_executor) {}
     91 
     92   // Gets or creates and caches an interpreter for the selection model.
     93   tflite::Interpreter* SelectionInterpreter();
     94 
     95   // Gets or creates and caches an interpreter for the classification model.
     96   tflite::Interpreter* ClassificationInterpreter();
     97 
     98  private:
     99   const ModelExecutor* selection_executor_;
    100   const ModelExecutor* classification_executor_;
    101 
    102   std::unique_ptr<tflite::Interpreter> selection_interpreter_;
    103   std::unique_ptr<tflite::Interpreter> classification_interpreter_;
    104 };
    105 
    106 // A text processing model that provides text classification, annotation,
    107 // selection suggestion for various types.
    108 // NOTE: This class is not thread-safe.
    109 class TextClassifier {
    110  public:
    111   static std::unique_ptr<TextClassifier> FromUnownedBuffer(
    112       const char* buffer, int size, const UniLib* unilib = nullptr);
    113   // Takes ownership of the mmap.
    114   static std::unique_ptr<TextClassifier> FromScopedMmap(
    115       std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib = nullptr);
    116   static std::unique_ptr<TextClassifier> FromFileDescriptor(
    117       int fd, int offset, int size, const UniLib* unilib = nullptr);
    118   static std::unique_ptr<TextClassifier> FromFileDescriptor(
    119       int fd, const UniLib* unilib = nullptr);
    120   static std::unique_ptr<TextClassifier> FromPath(
    121       const std::string& path, const UniLib* unilib = nullptr);
    122 
    123   // Returns true if the model is ready for use.
    124   bool IsInitialized() { return initialized_; }
    125 
    126   // Runs inference for given a context and current selection (i.e. index
    127   // of the first and one past last selected characters (utf8 codepoint
    128   // offsets)). Returns the indices (utf8 codepoint offsets) of the selection
    129   // beginning character and one past selection end character.
    130   // Returns the original click_indices if an error occurs.
    131   // NOTE: The selection indices are passed in and returned in terms of
    132   // UTF8 codepoints (not bytes).
    133   // Requires that the model is a smart selection model.
    134   CodepointSpan SuggestSelection(
    135       const std::string& context, CodepointSpan click_indices,
    136       const SelectionOptions& options = SelectionOptions::Default()) const;
    137 
    138   // Classifies the selected text given the context string.
    139   // Returns an empty result if an error occurs.
    140   std::vector<ClassificationResult> ClassifyText(
    141       const std::string& context, CodepointSpan selection_indices,
    142       const ClassificationOptions& options =
    143           ClassificationOptions::Default()) const;
    144 
    145   // Annotates given input text. The annotations are sorted by their position
    146   // in the context string and exclude spans classified as 'other'.
    147   std::vector<AnnotatedSpan> Annotate(
    148       const std::string& context,
    149       const AnnotationOptions& options = AnnotationOptions::Default()) const;
    150 
    151   // Exposes the feature processor for tests and evaluations.
    152   const FeatureProcessor* SelectionFeatureProcessorForTests() const;
    153   const FeatureProcessor* ClassificationFeatureProcessorForTests() const;
    154 
    155   // Exposes the date time parser for tests and evaluations.
    156   const DatetimeParser* DatetimeParserForTests() const;
    157 
    158   // String collection names for various classes.
    159   static const std::string& kOtherCollection;
    160   static const std::string& kPhoneCollection;
    161   static const std::string& kAddressCollection;
    162   static const std::string& kDateCollection;
    163 
    164  protected:
    165   struct ScoredChunk {
    166     TokenSpan token_span;
    167     float score;
    168   };
    169 
    170   // Constructs and initializes text classifier from given model.
    171   // Takes ownership of 'mmap', and thus owns the buffer that backs 'model'.
    172   TextClassifier(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
    173                  const UniLib* unilib)
    174       : model_(model),
    175         mmap_(std::move(*mmap)),
    176         owned_unilib_(nullptr),
    177         unilib_(internal::MaybeCreateUnilib(unilib, &owned_unilib_)) {
    178     ValidateAndInitialize();
    179   }
    180 
    181   // Constructs, validates and initializes text classifier from given model.
    182   // Does not own the buffer that backs 'model'.
    183   explicit TextClassifier(const Model* model, const UniLib* unilib)
    184       : model_(model),
    185         owned_unilib_(nullptr),
    186         unilib_(internal::MaybeCreateUnilib(unilib, &owned_unilib_)) {
    187     ValidateAndInitialize();
    188   }
    189 
    190   // Checks that model contains all required fields, and initializes internal
    191   // datastructures.
    192   void ValidateAndInitialize();
    193 
    194   // Initializes regular expressions for the regex model.
    195   bool InitializeRegexModel(ZlibDecompressor* decompressor);
    196 
    197   // Resolves conflicts in the list of candidates by removing some overlapping
    198   // ones. Returns indices of the surviving ones.
    199   // NOTE: Assumes that the candidates are sorted according to their position in
    200   // the span.
    201   bool ResolveConflicts(const std::vector<AnnotatedSpan>& candidates,
    202                         const std::string& context,
    203                         const std::vector<Token>& cached_tokens,
    204                         InterpreterManager* interpreter_manager,
    205                         std::vector<int>* result) const;
    206 
    207   // Resolves one conflict between candidates on indices 'start_index'
    208   // (inclusive) and 'end_index' (exclusive). Assigns the winning candidate
    209   // indices to 'chosen_indices'. Returns false if a problem arises.
    210   bool ResolveConflict(const std::string& context,
    211                        const std::vector<Token>& cached_tokens,
    212                        const std::vector<AnnotatedSpan>& candidates,
    213                        int start_index, int end_index,
    214                        InterpreterManager* interpreter_manager,
    215                        std::vector<int>* chosen_indices) const;
    216 
    217   // Gets selection candidates from the ML model.
    218   // Provides the tokens produced during tokenization of the context string for
    219   // reuse.
    220   bool ModelSuggestSelection(const UnicodeText& context_unicode,
    221                              CodepointSpan click_indices,
    222                              InterpreterManager* interpreter_manager,
    223                              std::vector<Token>* tokens,
    224                              std::vector<AnnotatedSpan>* result) const;
    225 
    226   // Classifies the selected text given the context string with the
    227   // classification model.
    228   // Returns true if no error occurred.
    229   bool ModelClassifyText(
    230       const std::string& context, const std::vector<Token>& cached_tokens,
    231       CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
    232       FeatureProcessor::EmbeddingCache* embedding_cache,
    233       std::vector<ClassificationResult>* classification_results) const;
    234 
    235   bool ModelClassifyText(
    236       const std::string& context, CodepointSpan selection_indices,
    237       InterpreterManager* interpreter_manager,
    238       FeatureProcessor::EmbeddingCache* embedding_cache,
    239       std::vector<ClassificationResult>* classification_results) const;
    240 
    241   // Returns a relative token span that represents how many tokens on the left
    242   // from the selection and right from the selection are needed for the
    243   // classifier input.
    244   TokenSpan ClassifyTextUpperBoundNeededTokens() const;
    245 
    246   // Classifies the selected text with the regular expressions models.
    247   // Returns true if any regular expression matched and the result was set.
    248   bool RegexClassifyText(const std::string& context,
    249                          CodepointSpan selection_indices,
    250                          ClassificationResult* classification_result) const;
    251 
    252   // Classifies the selected text with the date time model.
    253   // Returns true if there was a match and the result was set.
    254   bool DatetimeClassifyText(const std::string& context,
    255                             CodepointSpan selection_indices,
    256                             const ClassificationOptions& options,
    257                             ClassificationResult* classification_result) const;
    258 
    259   // Chunks given input text with the selection model and classifies the spans
    260   // with the classification model.
    261   // The annotations are sorted by their position in the context string and
    262   // exclude spans classified as 'other'.
    263   // Provides the tokens produced during tokenization of the context string for
    264   // reuse.
    265   bool ModelAnnotate(const std::string& context,
    266                      InterpreterManager* interpreter_manager,
    267                      std::vector<Token>* tokens,
    268                      std::vector<AnnotatedSpan>* result) const;
    269 
    270   // Groups the tokens into chunks. A chunk is a token span that should be the
    271   // suggested selection when any of its contained tokens is clicked. The chunks
    272   // are non-overlapping and are sorted by their position in the context string.
    273   // "num_tokens" is the total number of tokens available (as this method does
    274   // not need the actual vector of tokens).
    275   // "span_of_interest" is a span of all the tokens that could be clicked.
    276   // The resulting chunks all have to overlap with it and they cover this span
    277   // completely. The first and last chunk might extend beyond it.
    278   // The chunks vector is cleared before filling.
    279   bool ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
    280                   tflite::Interpreter* selection_interpreter,
    281                   const CachedFeatures& cached_features,
    282                   std::vector<TokenSpan>* chunks) const;
    283 
    284   // A helper method for ModelChunk(). It generates scored chunk candidates for
    285   // a click context model.
    286   // NOTE: The returned chunks can (and most likely do) overlap.
    287   bool ModelClickContextScoreChunks(
    288       int num_tokens, const TokenSpan& span_of_interest,
    289       const CachedFeatures& cached_features,
    290       tflite::Interpreter* selection_interpreter,
    291       std::vector<ScoredChunk>* scored_chunks) const;
    292 
    293   // A helper method for ModelChunk(). It generates scored chunk candidates for
    294   // a bounds-sensitive model.
    295   // NOTE: The returned chunks can (and most likely do) overlap.
    296   bool ModelBoundsSensitiveScoreChunks(
    297       int num_tokens, const TokenSpan& span_of_interest,
    298       const TokenSpan& inference_span, const CachedFeatures& cached_features,
    299       tflite::Interpreter* selection_interpreter,
    300       std::vector<ScoredChunk>* scored_chunks) const;
    301 
    302   // Produces chunks isolated by a set of regular expressions.
    303   bool RegexChunk(const UnicodeText& context_unicode,
    304                   const std::vector<int>& rules,
    305                   std::vector<AnnotatedSpan>* result) const;
    306 
    307   // Produces chunks from the datetime parser.
    308   bool DatetimeChunk(const UnicodeText& context_unicode,
    309                      int64 reference_time_ms_utc,
    310                      const std::string& reference_timezone,
    311                      const std::string& locales, ModeFlag mode,
    312                      std::vector<AnnotatedSpan>* result) const;
    313 
    314   // Returns whether a classification should be filtered.
    315   bool FilteredForAnnotation(const AnnotatedSpan& span) const;
    316   bool FilteredForClassification(
    317       const ClassificationResult& classification) const;
    318   bool FilteredForSelection(const AnnotatedSpan& span) const;
    319 
    320   const Model* model_;
    321 
    322   std::unique_ptr<const ModelExecutor> selection_executor_;
    323   std::unique_ptr<const ModelExecutor> classification_executor_;
    324   std::unique_ptr<const EmbeddingExecutor> embedding_executor_;
    325 
    326   std::unique_ptr<const FeatureProcessor> selection_feature_processor_;
    327   std::unique_ptr<const FeatureProcessor> classification_feature_processor_;
    328 
    329   std::unique_ptr<const DatetimeParser> datetime_parser_;
    330 
    331  private:
    332   struct CompiledRegexPattern {
    333     std::string collection_name;
    334     float target_classification_score;
    335     float priority_score;
    336     std::unique_ptr<UniLib::RegexPattern> pattern;
    337   };
    338 
    339   std::unique_ptr<ScopedMmap> mmap_;
    340   bool initialized_ = false;
    341   bool enabled_for_annotation_ = false;
    342   bool enabled_for_classification_ = false;
    343   bool enabled_for_selection_ = false;
    344   std::unordered_set<std::string> filtered_collections_annotation_;
    345   std::unordered_set<std::string> filtered_collections_classification_;
    346   std::unordered_set<std::string> filtered_collections_selection_;
    347 
    348   std::vector<CompiledRegexPattern> regex_patterns_;
    349   std::unordered_set<int> regex_approximate_match_pattern_ids_;
    350 
    351   // Indices into regex_patterns_ for the different modes.
    352   std::vector<int> annotation_regex_patterns_, classification_regex_patterns_,
    353       selection_regex_patterns_;
    354 
    355   std::unique_ptr<UniLib> owned_unilib_;
    356   const UniLib* unilib_;
    357 };
    358 
    359 namespace internal {
    360 
    361 // Helper function, which if the initial 'span' contains only white-spaces,
    362 // moves the selection to a single-codepoint selection on the left side
    363 // of this block of white-space.
    364 CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
    365                                             const UnicodeText& context_unicode,
    366                                             const UniLib& unilib);
    367 
    368 // Copies tokens from 'cached_tokens' that are
    369 // 'tokens_around_selection_to_copy' (on the left, and right) tokens distant
    370 // from the tokens that correspond to 'selection_indices'.
    371 std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
    372                                     CodepointSpan selection_indices,
    373                                     TokenSpan tokens_around_selection_to_copy);
    374 }  // namespace internal
    375 
// Interprets the buffer as a Model flatbuffer and returns it for reading.
//   buffer: start of the serialized model data.
//   size: length of 'buffer' (presumably in bytes - TODO confirm).
// NOTE(review): whether nullptr is returned for a malformed buffer cannot be
// determined from this declaration - confirm in the implementation.
const Model* ViewModel(const void* buffer, int size);
    378 
    379 }  // namespace libtextclassifier2
    380 
    381 #endif  // LIBTEXTCLASSIFIER_TEXT_CLASSIFIER_H_
    382