Home | History | Annotate | Download | only in smartselect
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 // Model parameter loading.
     18 
     19 #ifndef LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARAMS_H_
     20 #define LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARAMS_H_
     21 
#include <memory>
#include <utility>

#include "common/embedding-network.h"
#include "common/memory_image/embedding-network-params-from-image.h"
#include "smartselect/text-classification-model.pb.h"
     25 
     26 namespace libtextclassifier {
     27 
     28 class EmbeddingParams : public nlp_core::EmbeddingNetworkParamsFromImage {
     29  public:
     30   EmbeddingParams(const void* start, uint64 num_bytes, int context_size)
     31       : EmbeddingNetworkParamsFromImage(start, num_bytes),
     32         context_size_(context_size) {}
     33 
     34   int embeddings_size() const override { return context_size_ * 2 + 1; }
     35 
     36   int embedding_num_features_size() const override {
     37     return context_size_ * 2 + 1;
     38   }
     39 
     40   int embedding_num_features(int i) const override { return 1; }
     41 
     42   int embeddings_num_rows(int i) const override {
     43     return EmbeddingNetworkParamsFromImage::embeddings_num_rows(0);
     44   };
     45 
     46   int embeddings_num_cols(int i) const override {
     47     return EmbeddingNetworkParamsFromImage::embeddings_num_cols(0);
     48   };
     49 
     50   const void* embeddings_weights(int i) const override {
     51     return EmbeddingNetworkParamsFromImage::embeddings_weights(0);
     52   };
     53 
     54   nlp_core::QuantizationType embeddings_quant_type(int i) const override {
     55     return EmbeddingNetworkParamsFromImage::embeddings_quant_type(0);
     56   }
     57 
     58   const nlp_core::float16* embeddings_quant_scales(int i) const override {
     59     return EmbeddingNetworkParamsFromImage::embeddings_quant_scales(0);
     60   }
     61 
     62  private:
     63   int context_size_;
     64 };
     65 
     66 // Loads and holds the parameters of the inference network.
     67 //
     68 // This class overrides a couple of methods of EmbeddingNetworkParamsFromImage
     69 // because we only have one embedding matrix for all positions of context,
     70 // whereas the original class would have a separate one for each.
     71 class ModelParams : public nlp_core::EmbeddingNetworkParamsFromImage {
     72  public:
     73   const FeatureProcessorOptions& GetFeatureProcessorOptions() const {
     74     return feature_processor_options_;
     75   }
     76 
     77   const SelectionModelOptions& GetSelectionModelOptions() const {
     78     return selection_options_;
     79   }
     80 
     81   const SharingModelOptions& GetSharingModelOptions() const {
     82     return sharing_options_;
     83   }
     84 
     85   std::shared_ptr<EmbeddingParams> GetEmbeddingParams() const {
     86     return embedding_params_;
     87   }
     88 
     89  protected:
     90   int embeddings_size() const override {
     91     return embedding_params_->embeddings_size();
     92   }
     93 
     94   int embedding_num_features_size() const override {
     95     return embedding_params_->embedding_num_features_size();
     96   }
     97 
     98   int embedding_num_features(int i) const override {
     99     return embedding_params_->embedding_num_features(i);
    100   }
    101 
    102   int embeddings_num_rows(int i) const override {
    103     return embedding_params_->embeddings_num_rows(i);
    104   };
    105 
    106   int embeddings_num_cols(int i) const override {
    107     return embedding_params_->embeddings_num_cols(i);
    108   };
    109 
    110   const void* embeddings_weights(int i) const override {
    111     return embedding_params_->embeddings_weights(i);
    112   };
    113 
    114   nlp_core::QuantizationType embeddings_quant_type(int i) const override {
    115     return embedding_params_->embeddings_quant_type(i);
    116   }
    117 
    118   const nlp_core::float16* embeddings_quant_scales(int i) const override {
    119     return embedding_params_->embeddings_quant_scales(i);
    120   }
    121 
    122  private:
    123   friend ModelParams* ModelParamsBuilder(
    124       const void* start, uint64 num_bytes,
    125       std::shared_ptr<EmbeddingParams> external_embedding_params);
    126 
    127   ModelParams(const void* start, uint64 num_bytes,
    128               std::shared_ptr<EmbeddingParams> embedding_params,
    129               const SelectionModelOptions& selection_options,
    130               const SharingModelOptions& sharing_options,
    131               const FeatureProcessorOptions& feature_processor_options)
    132       : EmbeddingNetworkParamsFromImage(start, num_bytes),
    133         selection_options_(selection_options),
    134         sharing_options_(sharing_options),
    135         feature_processor_options_(feature_processor_options),
    136         context_size_(feature_processor_options_.context_size()),
    137         embedding_params_(std::move(embedding_params)) {}
    138 
    139   SelectionModelOptions selection_options_;
    140   SharingModelOptions sharing_options_;
    141   FeatureProcessorOptions feature_processor_options_;
    142   int context_size_;
    143   std::shared_ptr<EmbeddingParams> embedding_params_;
    144 };
    145 
    146 ModelParams* ModelParamsBuilder(
    147     const void* start, uint64 num_bytes,
    148     std::shared_ptr<EmbeddingParams> external_embedding_params);
    149 
    150 }  // namespace libtextclassifier
    151 
    152 #endif  // LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARAMS_H_
    153