Home | History | Annotate | Download | only in smartselect
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 // Inference code for the feed-forward text classification models.
     18 
     19 #ifndef LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
     20 #define LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
     21 
     22 #include <memory>
     23 #include <set>
     24 #include <string>
     25 
     26 #include "base.h"
     27 #include "common/embedding-network.h"
     28 #include "common/feature-extractor.h"
     29 #include "common/memory_image/embedding-network-params-from-image.h"
     30 #include "common/mmap.h"
     31 #include "smartselect/feature-processor.h"
     32 #include "smartselect/model-params.h"
     33 #include "smartselect/text-classification-model.pb.h"
     34 #include "smartselect/types.h"
     35 
     36 namespace libtextclassifier {
     37 
     38 // Feed-forward text classification model covering two tasks: smart
        // selection (see SuggestSelection) and smart sharing (see ClassifyText).
        // The selection and sharing sides each have their own options, params,
        // feature processor, embedding network and feature function members.
     39 class TextClassificationModel {
     40  public:
     41   // Loads the TextClassificationModel from the file referred to by the
     42   // int file descriptor `fd`.
        //
        // NOTE(review): this header exposes no public way to query whether
        // loading succeeded (initialized_ is private, with no accessor here);
        // presumably failures surface through the error returns documented on
        // SuggestSelection and ClassifyText -- confirm in the .cc.
     43   explicit TextClassificationModel(int fd);
     44 
     45   // Bit flags for the input selection. The values are distinct bits;
        // presumably they are OR-ed together and passed as the `input_flags`
        // argument of ClassifyText -- confirm against the implementation.
     46   enum SelectionInputFlags { SELECTION_IS_URL = 0x1, SELECTION_IS_EMAIL = 0x2 };
     47 
     48   // Runs inference for a given context and the current selection, where
     49   // click_indices holds the index of the first selected character and the
     50   // index one past the last selected character. Returns the suggested
     51   // selection as the same kind of (begin, one-past-end) pair of indices.
     52   // Returns the original click_indices if an error occurs.
     53   // NOTE: The selection indices are passed in and returned in terms of
     54   // UTF8 codepoints (not bytes).
     55   // Requires that the model is a smart selection model.
     56   CodepointSpan SuggestSelection(const std::string& context,
     57                                  CodepointSpan click_indices) const;
     58 
     59   // Classifies the selected text (click_indices) given the context string.
     60   // Requires that the model is a smart sharing model.
     61   // Returns an empty result if an error occurs.
        //
        // The result is a vector of (string, float) pairs; presumably each pair
        // is a collection name and its score -- confirm in the .cc.
        // `input_flags` defaults to 0 (no flags set).
        //
        // NOTE(review): std::vector and std::pair are used here, but this
        // header does not include <vector> or <utility> directly; presumably
        // they arrive transitively through the project headers -- consider
        // including them explicitly.
     62   std::vector<std::pair<std::string, float>> ClassifyText(
     63       const std::string& context, CodepointSpan click_indices,
     64       int input_flags = 0) const;
     65 
     66  protected:
     67   // Removes punctuation from the beginning and end of the selection and
     68   // returns the new selection span.
        // NOTE(review): presumably the characters to strip are the entries of
        // punctuation_to_strip_ -- confirm in the .cc.
     69   CodepointSpan StripPunctuation(CodepointSpan selection,
     70                                  const std::string& context) const;
     71 
     72   // During evaluation we need access to the selection feature processor.
        // Returns the raw pointer held by selection_feature_processor_; the
        // unique_ptr member keeps ownership, so callers must not delete it.
     73   FeatureProcessor* SelectionFeatureProcessor() const {
     74     return selection_feature_processor_.get();
     75   }
     76 
        // NOTE(review): the four collection names below are non-static const
        // members, so every instance carries its own copy of each string.
        //
     77   // Collection name when url hint is accepted.
     78   const std::string kUrlHintCollection = "url";
     79 
     80   // Collection name when email hint is accepted.
     81   const std::string kEmailHintCollection = "email";
     82 
     83   // Collection name for other.
     84   const std::string kOtherCollection = "other";
     85 
     86   // Collection name for phone.
     87   const std::string kPhoneCollection = "phone";
     88 
        // Options for the selection model and for the sharing model.
        // NOTE(review): presumably populated while the models are loaded --
        // confirm in the .cc.
     89   SelectionModelOptions selection_options_;
     90   SharingModelOptions sharing_options_;
     91 
     92  private:
        // Loads the models from the given mmap handle. Returns a bool;
        // presumably true on success -- confirm in the .cc.
     93   bool LoadModels(const nlp_core::MmapHandle& mmap_handle);
     94 
        // Inference helper that takes the feature processor, network and
        // feature vector function to use as explicit arguments, and returns an
        // nlp_core::EmbeddingNetwork::Vector.
        // NOTE(review): presumably shared by the selection and sharing paths,
        // with `selection_label_spans` as an output parameter -- confirm in
        // the .cc (including whether it may be null).
     95   nlp_core::EmbeddingNetwork::Vector InferInternal(
     96       const std::string& context, CodepointSpan span,
     97       const FeatureProcessor& feature_processor,
     98       const nlp_core::EmbeddingNetwork& network,
     99       const FeatureVectorFn& feature_vector_fn,
    100       std::vector<CodepointSpan>* selection_label_spans) const;
    101 
    102   // Returns a selection suggestion together with its score.
    103   std::pair<CodepointSpan, float> SuggestSelectionInternal(
    104       const std::string& context, CodepointSpan click_indices) const;
    105 
    106   // Returns a selection suggestion and makes sure it's symmetric. Internally
    107   // runs SuggestSelectionInternal several times.
    108   CodepointSpan SuggestSelectionSymmetrical(const std::string& context,
    109                                             CodepointSpan click_indices) const;
    110 
        // NOTE(review): initialized_ has no in-class initializer; this assumes
        // the constructor always assigns it -- confirm in the .cc.
        //
        // Data members are constructed in declaration order and destroyed in
        // reverse order, so do not reorder them without checking the .cc.
    111   bool initialized_;
        // Scoped mmap owned by this object; presumably the mapping of the model
        // file passed to the constructor -- confirm in the .cc.
    112   nlp_core::ScopedMmap mmap_;
        // Selection model state: params, feature processor, network and
        // feature function.
    113   std::unique_ptr<ModelParams> selection_params_;
    114   std::unique_ptr<FeatureProcessor> selection_feature_processor_;
    115   std::unique_ptr<nlp_core::EmbeddingNetwork> selection_network_;
    116   FeatureVectorFn selection_feature_fn_;
        // Sharing model state. Note that here the feature processor is declared
        // before the params, unlike the selection members above.
    117   std::unique_ptr<FeatureProcessor> sharing_feature_processor_;
    118   std::unique_ptr<ModelParams> sharing_params_;
    119   std::unique_ptr<nlp_core::EmbeddingNetwork> sharing_network_;
    120   FeatureVectorFn sharing_feature_fn_;
    121 
        // Set of int values; presumably the code points that StripPunctuation
        // treats as punctuation -- confirm in the .cc.
    122   std::set<int> punctuation_to_strip_;
    123 };
    124 
    125 // Parses the merged model image referred to by the file descriptor `fd`
    126 // and reads the ModelOptions proto from the selection model.
        // NOTE(review): presumably the proto is written to `*model_options` and
        // the bool result is true on success -- confirm in the .cc, including
        // whether `model_options` may be null.
    127 bool ReadSelectionModelOptions(int fd, ModelOptions* model_options);
    128 
    129 }  // namespace libtextclassifier
    130 
    131 #endif  // LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
    132