Home | History | Annotate | Download | only in native
      1 /*
      2  * Copyright (C) 2012 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
// This file declares the MulticlassPA class, which implements a simple
// linear multi-class classifier based on the multi-prototype version of
// the passive-aggressive (PA) algorithm.
     20 
     21 #ifndef LEARNINGFW_MULTICLASS_PA_H_
     22 #define LEARNINGFW_MULTICLASS_PA_H_
     23 
#include <cmath>
#include <utility>
#include <vector>
     26 
     27 const float kEpsilon = 1.0e-4;
     28 
     29 namespace learningfw {
     30 
     31 class MulticlassPA {
     32  public:
     33   MulticlassPA(int num_classes,
     34                int num_dimensions,
     35                float aggressiveness);
     36   virtual ~MulticlassPA();
     37 
     38   // Initialize all parameters to 0.0.
     39   void InitializeParameters();
     40 
     41   // Returns a random class that is different from the target class.
     42   int PickAClassExcept(int target);
     43 
     44   // Returns a random example.
     45   int PickAnExample(int num_examples);
     46 
     47   // Computes the score of a given input vector for a given parameter
     48   // vector, by computing the dot product between the two.
     49   float Score(const std::vector<float>& inputs,
     50               const std::vector<float>& parameters) const;
     51   float SparseScore(const std::vector<std::pair<int, float> >& inputs,
     52                     const std::vector<float>& parameters) const;
     53 
     54   // Returns the square of the L2 norm.
     55   float L2NormSquare(const std::vector<float>& inputs) const;
     56   float SparseL2NormSquare(const std::vector<std::pair<int, float> >& inputs) const;
     57 
     58   // Verify if the given example is correctly classified with margin with
     59   // respect to a random class.  If not, then modifies the corresponding
     60   // parameters using passive-aggressive.
     61   virtual float TrainOneExample(const std::vector<float>& inputs, int target);
     62   virtual float SparseTrainOneExample(
     63       const std::vector<std::pair<int, float> >& inputs, int target);
     64 
     65   // Iteratively train the model for num_iterations on the given dataset.
     66   float Train(const std::vector<std::pair<std::vector<float>, int> >& data,
     67               int num_iterations);
     68   float SparseTrain(
     69       const std::vector<std::pair<std::vector<std::pair<int, float> >, int> >& data,
     70       int num_iterations);
     71 
     72   // Returns the best class for a given input vector.
     73   virtual int GetClass(const std::vector<float>& inputs);
     74   virtual int SparseGetClass(const std::vector<std::pair<int, float> >& inputs);
     75 
     76   // Computes the test error of a given test set on the current model.
     77   float Test(const std::vector<std::pair<std::vector<float>, int> >& data);
     78   float SparseTest(
     79       const std::vector<std::pair<std::vector<std::pair<int, float> >, int> >& data);
     80 
     81   // A few accessors used by the sub-classes.
     82   inline float aggressiveness() const {
     83     return aggressiveness_;
     84   }
     85 
     86   inline std::vector<std::vector<float> >& parameters() {
     87     return parameters_;
     88   }
     89 
     90   inline std::vector<std::vector<float> >* mutable_parameters() {
     91     return &parameters_;
     92   }
     93 
     94   inline int num_classes() const {
     95     return num_classes_;
     96   }
     97 
     98   inline int num_dimensions() const {
     99     return num_dimensions_;
    100   }
    101 
    102  private:
    103   // Keeps the current parameter vector.
    104   std::vector<std::vector<float> > parameters_;
    105 
    106   // The number of classes of the problem.
    107   int num_classes_;
    108 
    109   // The number of dimensions of the input vectors.
    110   int num_dimensions_;
    111 
    112   // Controls how "aggressive" training should be.
    113   float aggressiveness_;
    114 
    115 };
    116 }  // namespace learningfw
    117 #endif  // LEARNINGFW_MULTICLASS_PA_H_
    118