/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_
#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_

#include <string>

#include "lang_id/common/fel/task-context.h"
#include "lang_id/common/lite_base/float16.h"
#include "lang_id/common/lite_base/logging.h"

namespace libtextclassifier3 {

enum class QuantizationType {
  NONE = 0,

  // Quantization to 8 bit unsigned ints.
  UINT8,

  // Quantization to 4 bit unsigned ints.
  UINT4,

  // Quantization to 16 bit floats, the type defined in
  // lang_id/common/lite_base/float16.h
  FLOAT16,

  // NOTE: for backward compatibility, if you add a new value to this enum, add
  // it *at the end*, such that you do not change the integer values of the
  // existing enum values.
};

// Converts "UINT8" -> QuantizationType::UINT8, and so on.
QuantizationType ParseQuantizationType(const string &s);
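
// Example (sketch): mapping a model-config string onto the enum; the "UINT8"
// literal below is purely illustrative.
//
//   QuantizationType type = ParseQuantizationType("UINT8");
//   if (type == QuantizationType::UINT8) {
//     // ... select the 8-bit dequantization path ...
//   }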

// API for accessing parameters for a feed-forward neural network with
// embeddings.
//
// In fact, we provide two APIs: a high-level (and highly recommended) API,
// with methods named in UpperCamelCase (e.g., GetEmbeddingMatrix()) and a
// low-level API, using C-style names (e.g., softmax_num_cols()).
//
// Note: the API below is meant to allow the inference code (the class
// libtextclassifier3::mobile::EmbeddingNetwork) to use the data directly, with
// no need for transposing any matrix (which would require extra overhead on
// mobile devices).  Hence, as indicated by the comments for the API methods,
// some of the matrices below are the transposes of the corresponding matrices
// from the original proto.
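//
// For instance, because each matrix is already stored row-major in its final
// orientation, a matrix-vector product can stream through memory linearly.
// Sketch (assumes an unquantized hidden layer; "params", "input" (w.cols
// floats), and "output" (w.rows floats) are illustrative names):
//
//   Matrix w = params.GetHiddenLayerMatrix(0);
//   const float *data = static_cast<const float *>(w.elements);
//   for (int r = 0; r < w.rows; ++r) {
//     float sum = 0.0f;
//     for (int c = 0; c < w.cols; ++c) {
//       sum += data[r * w.cols + c] * input[c];
//     }
//     output[r] = sum;
//   }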
class EmbeddingNetworkParams {
 public:
  virtual ~EmbeddingNetworkParams() {}

  // Returns true if these params are valid.  False otherwise (e.g., if the
  // underlying data is corrupted).  If is_valid() returns false, clients should
  // not call any other method on that instance of EmbeddingNetworkParams.  If
  // is_valid() returns true, then calls to the API methods below should not
  // crash *if they are called with index parameters in bounds*.  E.g., if
  // is_valid() and 0 <= i < embeddings_size(), then GetEmbeddingMatrix(i)
  // should not crash.
  virtual bool is_valid() const = 0;

  // **** High-level API.

  // Simple representation of a matrix.  This small struct doesn't own any
  // resources and intentionally supports copy / assign, to simplify our APIs.
  struct Matrix {
    // Number of rows.
    int rows = 0;

    // Number of columns.
    int cols = 0;

    QuantizationType quant_type = QuantizationType::NONE;

    // Pointer to matrix elements, in row-major order
    // (https://en.wikipedia.org/wiki/Row-major_order).  Not owned.
    const void *elements = nullptr;

    // Quantization scales: one scale for each row.
    const ::libtextclassifier3::mobile::float16 *quant_scales = nullptr;
  };

  // Returns i-th embedding matrix.  Crashes on out of bounds indices.
  //
  // This is the transpose of the corresponding matrix from the original proto.
  Matrix GetEmbeddingMatrix(int i) const {
    CheckIndex(i, embeddings_size(), "embedding matrix");
    Matrix matrix;
    matrix.rows = embeddings_num_rows(i);
    matrix.cols = embeddings_num_cols(i);
    matrix.elements = embeddings_weights(i);
    matrix.quant_type = embeddings_quant_type(i);
    matrix.quant_scales = embeddings_quant_scales(i);
    return matrix;
  }
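
  // Usage sketch: iterating over all embedding matrices ("params" stands for
  // any valid implementation of this interface):
  //
  //   if (params.is_valid()) {
  //     for (int i = 0; i < params.embeddings_size(); ++i) {
  //       EmbeddingNetworkParams::Matrix m = params.GetEmbeddingMatrix(i);
  //       // m.elements points to m.rows * m.cols elements, in row-major
  //       // order; their type depends on m.quant_type.
  //     }
  //   }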

  // Returns weight matrix for i-th hidden layer.  Crashes on out of bounds
  // indices.
  //
  // This is the transpose of the corresponding matrix from the original proto.
  Matrix GetHiddenLayerMatrix(int i) const {
    CheckIndex(i, hidden_size(), "hidden layer");
    Matrix matrix;
    matrix.rows = hidden_num_rows(i);
    matrix.cols = hidden_num_cols(i);

    // Quantization type is reported by the underlying implementation.
    matrix.quant_type = hidden_weights_quant_type(i);
    matrix.elements = hidden_weights(i);
    return matrix;
  }

  // Returns bias for i-th hidden layer.  Technically a Matrix, but we expect it
  // to be a row/column vector (i.e., num rows or num cols is 1).  However, we
  // don't CHECK for that: we just provide access to underlying data.  Crashes
  // on out of bounds indices.
  Matrix GetHiddenLayerBias(int i) const {
    CheckIndex(i, hidden_bias_size(), "hidden layer bias");
    Matrix matrix;
    matrix.rows = hidden_bias_num_rows(i);
    matrix.cols = hidden_bias_num_cols(i);

    // Quantization not supported here.
    matrix.quant_type = QuantizationType::NONE;
    matrix.elements = hidden_bias_weights(i);
    return matrix;
  }

  // Returns true if a softmax layer exists.
  bool HasSoftmax() const {
    return softmax_size() == 1;
  }

  // Returns weight matrix for the softmax layer.  Note: should be called only
  // if HasSoftmax() is true.
  //
  // This is the transpose of the corresponding matrix from the original proto.
  Matrix GetSoftmaxMatrix() const {
    SAFTM_CHECK(HasSoftmax()) << "No softmax layer.";
    Matrix matrix;
    matrix.rows = softmax_num_rows(0);
    matrix.cols = softmax_num_cols(0);

    // Quantization type is reported by the underlying implementation.
    matrix.quant_type = softmax_weights_quant_type(0);
    matrix.elements = softmax_weights(0);
    return matrix;
  }

  // Returns bias for the softmax layer.  Technically a Matrix, but we expect it
  // to be a row/column vector (i.e., num rows or num cols is 1).  However, we
  // don't CHECK for that: we just provide access to underlying data.
  Matrix GetSoftmaxBias() const {
    SAFTM_CHECK(HasSoftmax()) << "No softmax layer.";
    Matrix matrix;
    matrix.rows = softmax_bias_num_rows(0);
    matrix.cols = softmax_bias_num_cols(0);

    // Quantization not supported here.
    matrix.quant_type = QuantizationType::NONE;
    matrix.elements = softmax_bias_weights(0);
    return matrix;
  }
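
  // Usage sketch: the softmax layer is optional, so guard both getters with
  // HasSoftmax():
  //
  //   if (params.HasSoftmax()) {
  //     EmbeddingNetworkParams::Matrix weights = params.GetSoftmaxMatrix();
  //     EmbeddingNetworkParams::Matrix bias = params.GetSoftmaxBias();
  //     // ... compute the final layer from weights and bias ...
  //   }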

  // Updates the EmbeddingNetwork-related parameters from task_context.  Returns
  // true on success, false on error.
  virtual bool UpdateTaskContextParameters(
      mobile::TaskContext *task_context) = 0;

  // **** Low-level API.
  //
  // * Most low-level API methods are documented by giving the equivalent call
  //   on "proto", the original proto (of type EmbeddingNetworkProto) that was
  //   used to generate the C++ code.
  //
  // * To simplify our generation code, optional proto fields of message type
  //   are treated as repeated fields with 0 or 1 instances.  As such, we have
  //   *_size() methods for such optional fields: they return 0 or 1.
  //
  // * "transpose(M)" denotes the transpose of a matrix M.

  // ** Access methods for repeated MatrixParams embeddings.
  //
  // Returns proto.embeddings_size().
  virtual int embeddings_size() const = 0;

  // Returns number of rows of transpose(proto.embeddings(i)).
  virtual int embeddings_num_rows(int i) const = 0;

  // Returns number of columns of transpose(proto.embeddings(i)).
  virtual int embeddings_num_cols(int i) const = 0;

  // Returns pointer to elements of transpose(proto.embeddings(i)), in row-major
  // order.  NOTE: for unquantized embeddings, this returns a pointer to float;
  // for quantized embeddings, this returns a pointer to uint8.
  virtual const void *embeddings_weights(int i) const = 0;

  virtual QuantizationType embeddings_quant_type(int i) const {
    return QuantizationType::NONE;
  }

  virtual const ::libtextclassifier3::mobile::float16 *embeddings_quant_scales(
      int i) const {
    return nullptr;
  }
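
  // Dequantization sketch for UINT8 embeddings.  NOTE: this assumes the usual
  // symmetric per-row scheme (one scale per row, no zero point) and a
  // float16-to-float conversion helper, named Float16To32 here for
  // illustration; the inference code is the authority on the actual scheme.
  //
  //   const int cols = params.embeddings_num_cols(i);
  //   const uint8 *q =
  //       static_cast<const uint8 *>(params.embeddings_weights(i));
  //   const ::libtextclassifier3::mobile::float16 *scales =
  //       params.embeddings_quant_scales(i);
  //   const float scale = Float16To32(scales[row]);  // assumed helper
  //   const float value = scale * q[row * cols + col];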

  // ** Access methods for repeated MatrixParams hidden.
  //
  // Returns embedding_network_proto.hidden_size().
  virtual int hidden_size() const = 0;

  // Returns embedding_network_proto.hidden(i).rows().
  virtual int hidden_num_rows(int i) const = 0;

  // Returns embedding_network_proto.hidden(i).cols().
  virtual int hidden_num_cols(int i) const = 0;

  // Returns quantization mode for the weights of the i-th hidden layer.
  virtual QuantizationType hidden_weights_quant_type(int i) const {
    return QuantizationType::NONE;
  }

  // Returns pointer to elements of embedding_network_proto.hidden(i), in
  // row-major order.  NOTE: as with embeddings_weights(), the element type
  // depends on hidden_weights_quant_type(i): float if NONE, quantized
  // otherwise.
  virtual const void *hidden_weights(int i) const = 0;

  // ** Access methods for repeated MatrixParams hidden_bias.
  //
  // Returns proto.hidden_bias_size().
  virtual int hidden_bias_size() const = 0;

  // Returns number of rows of proto.hidden_bias(i).
  virtual int hidden_bias_num_rows(int i) const = 0;

  // Returns number of columns of proto.hidden_bias(i).
  virtual int hidden_bias_num_cols(int i) const = 0;

  // Returns pointer to elements of proto.hidden_bias(i), in row-major order.
  virtual const void *hidden_bias_weights(int i) const = 0;

  // ** Access methods for optional MatrixParams softmax.
  //
  // Returns 1 if proto has optional field softmax, 0 otherwise.
  virtual int softmax_size() const = 0;

  // Returns number of rows of transpose(proto.softmax()).
  virtual int softmax_num_rows(int i) const = 0;

  // Returns number of columns of transpose(proto.softmax()).
  virtual int softmax_num_cols(int i) const = 0;

  // Returns quantization mode for the softmax weights.
  virtual QuantizationType softmax_weights_quant_type(int i) const {
    return QuantizationType::NONE;
  }

  // Returns pointer to elements of transpose(proto.softmax()), in row-major
  // order.
  virtual const void *softmax_weights(int i) const = 0;

  // ** Access methods for optional MatrixParams softmax_bias.
  //
  // Returns 1 if proto has optional field softmax_bias, 0 otherwise.
  virtual int softmax_bias_size() const = 0;

  // Returns number of rows of proto.softmax_bias().
  virtual int softmax_bias_num_rows(int i) const = 0;

  // Returns number of columns of proto.softmax_bias().
  virtual int softmax_bias_num_cols(int i) const = 0;

  // Returns pointer to elements of proto.softmax_bias(), in row-major order.
  virtual const void *softmax_bias_weights(int i) const = 0;

  // ** Access methods for repeated int32 embedding_num_features.
  //
  // Returns proto.embedding_num_features_size().
  virtual int embedding_num_features_size() const = 0;

  // Returns proto.embedding_num_features(i).
  virtual int embedding_num_features(int i) const = 0;
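
  // Sketch: in the usual embedding-network layout, the input to the first
  // hidden layer concatenates embedding_num_features(i) embedding vectors per
  // channel i; assuming each row of GetEmbeddingMatrix(i) is one embedding
  // vector (an assumption about the model layout, not a guarantee of this
  // interface), the total input dimension would be:
  //
  //   int input_dim = 0;
  //   for (int i = 0; i < params.embedding_num_features_size(); ++i) {
  //     input_dim +=
  //         params.embedding_num_features(i) * params.embeddings_num_cols(i);
  //   }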

  // ** Access methods for is_precomputed
  //
  // Returns proto.has_is_precomputed().
  virtual bool has_is_precomputed() const = 0;

  // Returns proto.is_precomputed().
  virtual bool is_precomputed() const = 0;

 protected:
  void CheckIndex(int index, int size, const string &description) const {
    SAFTM_CHECK_GE(index, 0)
        << "Out-of-range index for " << description << ": " << index;
    SAFTM_CHECK_LT(index, size)
        << "Out-of-range index for " << description << ": " << index;
  }
};  // class EmbeddingNetworkParams

}  // namespace libtextclassifier3

#endif  // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_