/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_H_
#define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_H_

#include <algorithm>
#include <string>

#include "common/float16.h"
#include "common/task-context.h"
#include "common/task-spec.pb.h"
#include "util/base/logging.h"

namespace libtextclassifier {
namespace nlp_core {

enum class QuantizationType { NONE = 0, UINT8 };

// API for accessing parameters for a feed-forward neural network with
// embeddings.
//
// Note: this API is closely related to embedding-network.proto.  The reason we
// have a separate API is that the proto may not be the only way of packaging
// these parameters.
class EmbeddingNetworkParams {
 public:
  virtual ~EmbeddingNetworkParams() {}

  // **** High-level API.

  // Simple representation of a matrix.  This small struct, which doesn't own
  // any resources, intentionally supports copy / assign to simplify our APIs.
  struct Matrix {
    // Number of rows.
    int rows;

    // Number of columns.
    int cols;

    QuantizationType quant_type;

    // Pointer to matrix elements, in row-major order
    // (https://en.wikipedia.org/wiki/Row-major_order).  Not owned.
    const void *elements;

    // Quantization scales: one scale for each row.
    const float16 *quant_scales;
  };

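  // Example (sketch, not part of the API): reading element (r, c) of a Matrix.
  // For UINT8-quantized matrices we assume the common convention that each row
  // is scaled by its per-row factor from quant_scales, with Float16To32 (from
  // common/float16.h) converting the scale to float; the concrete subclass
  // defines the actual layout.
  //
  //   float GetElement(const Matrix &m, int r, int c) {
  //     if (m.quant_type == QuantizationType::NONE) {
  //       return reinterpret_cast<const float *>(m.elements)[r * m.cols + c];
  //     }
  //     // UINT8: dequantize using the row's scale (assumed convention).
  //     const uint8 *bytes = reinterpret_cast<const uint8 *>(m.elements);
  //     return Float16To32(m.quant_scales[r]) * bytes[r * m.cols + c];
  //   }
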
  // Returns number of embedding spaces.
  int GetNumEmbeddingSpaces() const {
    if (embeddings_size() != embedding_num_features_size()) {
      TC_LOG(ERROR) << "Embedding spaces mismatch " << embeddings_size()
                    << " != " << embedding_num_features_size();
    }
    return std::max(0,
                    std::min(embeddings_size(), embedding_num_features_size()));
  }

  // Returns embedding matrix for the i-th embedding space.
  //
  // NOTE: i must be in [0, GetNumEmbeddingSpaces()).  Undefined behavior
  // otherwise.
  Matrix GetEmbeddingMatrix(int i) const {
    TC_DCHECK(InRange(i, embeddings_size()));
    Matrix matrix;
    matrix.rows = embeddings_num_rows(i);
    matrix.cols = embeddings_num_cols(i);
    matrix.elements = embeddings_weights(i);
    matrix.quant_type = embeddings_quant_type(i);
    matrix.quant_scales = embeddings_quant_scales(i);
    return matrix;
  }

  // Returns number of features in i-th embedding space.
  //
  // NOTE: i must be in [0, GetNumEmbeddingSpaces()).  Undefined behavior
  // otherwise.
  int GetNumFeaturesInEmbeddingSpace(int i) const {
    TC_DCHECK(InRange(i, embedding_num_features_size()));
    return std::max(0, embedding_num_features(i));
  }

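  // Example (sketch): iterating over all embedding spaces; `params` stands for
  // an instance of any concrete subclass.  In this kind of network, each
  // feature value typically selects a row of its space's embedding matrix, and
  // the concatenated rows form the input to the first hidden layer.
  //
  //   for (int i = 0; i < params.GetNumEmbeddingSpaces(); ++i) {
  //     const Matrix embeddings = params.GetEmbeddingMatrix(i);
  //     const int num_features = params.GetNumFeaturesInEmbeddingSpace(i);
  //     // ... look up num_features rows from `embeddings` ...
  //   }
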
  // Returns number of hidden layers in the neural network.  Each such layer
  // has a weight matrix and a bias vector (a matrix with one column).
  int GetNumHiddenLayers() const {
    if (hidden_size() != hidden_bias_size()) {
      TC_LOG(ERROR) << "Hidden layer mismatch " << hidden_size()
                    << " != " << hidden_bias_size();
    }
    return std::max(0, std::min(hidden_size(), hidden_bias_size()));
  }

  // Returns weight matrix for i-th hidden layer.
  //
  // NOTE: i must be in [0, GetNumHiddenLayers()).  Undefined behavior
  // otherwise.
  Matrix GetHiddenLayerMatrix(int i) const {
    TC_DCHECK(InRange(i, hidden_size()));
    Matrix matrix;
    matrix.rows = hidden_num_rows(i);
    matrix.cols = hidden_num_cols(i);

    // Quantization not supported here.
    matrix.quant_type = QuantizationType::NONE;
    matrix.elements = hidden_weights(i);
    return matrix;
  }

  // Returns bias matrix for i-th hidden layer.  Technically a Matrix, but we
  // expect it to be a vector (i.e., num cols is 1).
  //
  // NOTE: i must be in [0, GetNumHiddenLayers()).  Undefined behavior
  // otherwise.
  Matrix GetHiddenLayerBias(int i) const {
    TC_DCHECK(InRange(i, hidden_bias_size()));
    Matrix matrix;
    matrix.rows = hidden_bias_num_rows(i);
    matrix.cols = hidden_bias_num_cols(i);

    // Quantization not supported here.
    matrix.quant_type = QuantizationType::NONE;
    matrix.elements = hidden_bias_weights(i);
    return matrix;
  }

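  // Example (sketch): applying one hidden layer to an input vector, assuming
  // the common convention y = ReLU(x * W + b), where x is a row vector with
  // W.rows elements and the bias b has W.cols elements stored contiguously.
  // The actual orientation of the matrices is defined by the concrete
  // subclass.
  //
  //   std::vector<float> ApplyLayer(const Matrix &w, const Matrix &b,
  //                                 const std::vector<float> &x) {
  //     const float *weights = reinterpret_cast<const float *>(w.elements);
  //     const float *bias = reinterpret_cast<const float *>(b.elements);
  //     std::vector<float> y(w.cols);
  //     for (int j = 0; j < w.cols; ++j) {
  //       float sum = bias[j];
  //       for (int i = 0; i < w.rows; ++i) {
  //         sum += x[i] * weights[i * w.cols + j];  // row-major indexing
  //       }
  //       y[j] = std::max(0.0f, sum);  // ReLU non-linearity (assumed)
  //     }
  //     return y;
  //   }
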
  // Returns true if a softmax layer exists.
  bool HasSoftmaxLayer() const {
    if (softmax_size() != softmax_bias_size()) {
      TC_LOG(ERROR) << "Softmax layer mismatch " << softmax_size()
                    << " != " << softmax_bias_size();
    }
    return (softmax_size() == 1) && (softmax_bias_size() == 1);
  }

  // Returns weight matrix for the softmax layer.
  //
  // NOTE: Should be called only if HasSoftmaxLayer() is true.  Undefined
  // behavior otherwise.
  Matrix GetSoftmaxMatrix() const {
    TC_DCHECK(softmax_size() == 1);
    Matrix matrix;
    matrix.rows = softmax_num_rows(0);
    matrix.cols = softmax_num_cols(0);

    // Quantization not supported here.
    matrix.quant_type = QuantizationType::NONE;
    matrix.elements = softmax_weights(0);
    return matrix;
  }

  // Returns bias for the softmax layer.  Technically a Matrix, but we expect
  // it to be a column vector (i.e., num cols is 1).
  //
  // NOTE: Should be called only if HasSoftmaxLayer() is true.  Undefined
  // behavior otherwise.
  Matrix GetSoftmaxBias() const {
    TC_DCHECK(softmax_bias_size() == 1);
    Matrix matrix;
    matrix.rows = softmax_bias_num_rows(0);
    matrix.cols = softmax_bias_num_cols(0);

    // Quantization not supported here.
    matrix.quant_type = QuantizationType::NONE;
    matrix.elements = softmax_bias_weights(0);
    return matrix;
  }

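  // Example (sketch): the softmax layer is optional, so guard access with
  // HasSoftmaxLayer().
  //
  //   if (params.HasSoftmaxLayer()) {
  //     const Matrix softmax = params.GetSoftmaxMatrix();
  //     const Matrix softmax_bias = params.GetSoftmaxBias();
  //     // Compute logits from the last hidden activations, using the same
  //     // matrix/bias convention as for the hidden layers.
  //   }
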
  // Updates task_context with the EmbeddingNetwork-related parameters.
  // Returns true on success, false on error.
  virtual bool UpdateTaskContextParameters(TaskContext *task_context) {
    const TaskSpec *task_spec = GetTaskSpec();
    if (task_spec == nullptr) {
      TC_LOG(ERROR) << "Unable to get TaskSpec";
      return false;
    }
    for (const TaskSpec::Parameter &parameter : task_spec->parameter()) {
      task_context->SetParameter(parameter.name(), parameter.value());
    }
    return true;
  }

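  // Example (sketch): typical use before feature extraction; `params` is a
  // pointer to an instance of a concrete subclass.
  //
  //   TaskContext context;
  //   if (!params->UpdateTaskContextParameters(&context)) {
  //     // Handle the error: the underlying TaskSpec was not available.
  //   }
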
  // Returns a pointer to a TaskSpec with the EmbeddingNetwork-related
  // parameters.  Returns nullptr in case of problems.  Ownership of the
  // returned pointer is *not* transferred to the caller.
  virtual const TaskSpec *GetTaskSpec() {
    TC_LOG(ERROR) << "Not implemented";
    return nullptr;
  }

 protected:
  // **** Low-level API.
  //
  // * Most low-level API methods are documented by giving the equivalent call
  //   on proto, the original proto (of type EmbeddingNetworkProto) from which
  //   the C++ code was generated.
  //
  // * To simplify our generation code, optional proto fields of message type
  //   are treated as repeated fields with 0 or 1 instances.  As such, we have
  //   *_size() methods for such optional fields: they return 0 or 1.
  //
  // * "transpose(M)" denotes the transpose of a matrix M.
  //
  // * Behavior is undefined when trying to retrieve a piece of data that does
  //   not exist: e.g., embeddings_num_rows(5) if embeddings_size() == 2.

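  // Example (sketch): a hypothetical concrete subclass backed directly by an
  // EmbeddingNetworkProto could implement the size accessors by forwarding to
  // the proto, e.g.:
  //
  //   int embeddings_size() const override {
  //     return proto_.embeddings_size();
  //   }
  //   int hidden_size() const override {
  //     return proto_.hidden_size();
  //   }
  //
  // Real subclasses may use other storage (e.g., generated C++ arrays), which
  // is the reason this indirection exists.
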
  // ** Access methods for repeated MatrixParams embeddings.
  //
  // Returns proto.embeddings_size().
  virtual int embeddings_size() const = 0;

  // Returns number of rows of transpose(proto.embeddings(i)).
  virtual int embeddings_num_rows(int i) const = 0;

  // Returns number of columns of transpose(proto.embeddings(i)).
  virtual int embeddings_num_cols(int i) const = 0;

  // Returns pointer to elements of transpose(proto.embeddings(i)), in row-major
  // order.  NOTE: for unquantized embeddings, this returns a pointer to float;
  // for quantized embeddings, this returns a pointer to uint8.
  virtual const void *embeddings_weights(int i) const = 0;

  virtual QuantizationType embeddings_quant_type(int i) const {
    return QuantizationType::NONE;
  }

  virtual const float16 *embeddings_quant_scales(int i) const {
    return nullptr;
  }

  // ** Access methods for repeated MatrixParams hidden.
  //
  // Returns embedding_network_proto.hidden_size().
  virtual int hidden_size() const = 0;

  // Returns embedding_network_proto.hidden(i).rows().
  virtual int hidden_num_rows(int i) const = 0;

  // Returns embedding_network_proto.hidden(i).cols().
  virtual int hidden_num_cols(int i) const = 0;

  // Returns pointer to beginning of array of floats with all values from
  // embedding_network_proto.hidden(i).
  virtual const void *hidden_weights(int i) const = 0;

  // ** Access methods for repeated MatrixParams hidden_bias.
  //
  // Returns proto.hidden_bias_size().
  virtual int hidden_bias_size() const = 0;

  // Returns number of rows of proto.hidden_bias(i).
  virtual int hidden_bias_num_rows(int i) const = 0;

  // Returns number of columns of proto.hidden_bias(i).
  virtual int hidden_bias_num_cols(int i) const = 0;

  // Returns pointer to elements of proto.hidden_bias(i), in row-major order.
  virtual const void *hidden_bias_weights(int i) const = 0;

  // ** Access methods for optional MatrixParams softmax.
  //
  // Returns 1 if proto has optional field softmax, 0 otherwise.
  virtual int softmax_size() const = 0;

  // Returns number of rows of transpose(proto.softmax()).
  virtual int softmax_num_rows(int i) const = 0;

  // Returns number of columns of transpose(proto.softmax()).
  virtual int softmax_num_cols(int i) const = 0;

  // Returns pointer to elements of transpose(proto.softmax()), in row-major
  // order.
  virtual const void *softmax_weights(int i) const = 0;

  // ** Access methods for optional MatrixParams softmax_bias.
  //
  // Returns 1 if proto has optional field softmax_bias, 0 otherwise.
  virtual int softmax_bias_size() const = 0;

  // Returns number of rows of proto.softmax_bias().
  virtual int softmax_bias_num_rows(int i) const = 0;

  // Returns number of columns of proto.softmax_bias().
  virtual int softmax_bias_num_cols(int i) const = 0;

  // Returns pointer to elements of proto.softmax_bias(), in row-major order.
  virtual const void *softmax_bias_weights(int i) const = 0;

  // ** Access methods for repeated int32 embedding_num_features.
  //
  // Returns proto.embedding_num_features_size().
  virtual int embedding_num_features_size() const = 0;

  // Returns proto.embedding_num_features(i).
  virtual int embedding_num_features(int i) const = 0;

  // Returns true if and only if index is in range [0, size).  Logs an error
  // message otherwise.
  static bool InRange(int index, int size) {
    if ((index < 0) || (index >= size)) {
      TC_LOG(ERROR) << "Index " << index << " outside [0, " << size << ")";
      return false;
    }
    return true;
  }
};  // class EmbeddingNetworkParams

}  // namespace nlp_core
}  // namespace libtextclassifier

#endif  // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_H_