Home | History | Annotate | Download | only in native
      1 /*
      2  * Copyright (C) 2012 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 //
     18 // This file contains the MulticlassPA class which implements a simple
     19 // linear multi-class classifier based on the multi-prototype version of
     20 // passive aggressive.
     21 
     22 #include "native/multiclass_pa.h"
     23 
     24 #include <stdlib.h>
     25 
     26 using std::vector;
     27 using std::pair;
     28 
     29 namespace learningfw {
     30 
     31 float RandFloat() {
     32   return static_cast<float>(rand()) / RAND_MAX;
     33 }
     34 
     35 MulticlassPA::MulticlassPA(int num_classes,
     36                            int num_dimensions,
     37                            float aggressiveness)
     38     : num_classes_(num_classes),
     39       num_dimensions_(num_dimensions),
     40       aggressiveness_(aggressiveness) {
     41   InitializeParameters();
     42 }
     43 
     44 MulticlassPA::~MulticlassPA() {
     45 }
     46 
     47 void MulticlassPA::InitializeParameters() {
     48   parameters_.resize(num_classes_);
     49   for (int i = 0; i < num_classes_; ++i) {
     50     parameters_[i].resize(num_dimensions_);
     51     for (int j = 0; j < num_dimensions_; ++j) {
     52       parameters_[i][j] = 0.0;
     53     }
     54   }
     55 }
     56 
     57 int MulticlassPA::PickAClassExcept(int target) {
     58   int picked;
     59   do {
     60     picked = static_cast<int>(RandFloat() * num_classes_);
     61     //    picked = static_cast<int>(random_.RandFloat() * num_classes_);
     62   } while (target == picked);
     63   return picked;
     64 }
     65 
     66 int MulticlassPA::PickAnExample(int num_examples) {
     67   return static_cast<int>(RandFloat() * num_examples);
     68 }
     69 
     70 float MulticlassPA::Score(const vector<float>& inputs,
     71                           const vector<float>& parameters) const {
     72   // CHECK_EQ(inputs.size(), parameters.size());
     73   float result = 0.0;
     74   for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
     75     result += inputs[i] * parameters[i];
     76   }
     77   return result;
     78 }
     79 
     80 float MulticlassPA::SparseScore(const vector<pair<int, float> >& inputs,
     81                                 const vector<float>& parameters) const {
     82   float result = 0.0;
     83   for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
     84     //DCHECK_GE(inputs[i].first, 0);
     85     //DCHECK_LT(inputs[i].first, parameters.size());
     86     result += inputs[i].second * parameters[inputs[i].first];
     87   }
     88   return result;
     89 }
     90 
     91 float MulticlassPA::L2NormSquare(const vector<float>& inputs) const {
     92   float norm = 0;
     93   for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
     94     norm += inputs[i] * inputs[i];
     95   }
     96   return norm;
     97 }
     98 
     99 float MulticlassPA::SparseL2NormSquare(
    100     const vector<pair<int, float> >& inputs) const {
    101   float norm = 0;
    102   for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
    103     norm += inputs[i].second * inputs[i].second;
    104   }
    105   return norm;
    106 }
    107 
    108 float MulticlassPA::TrainOneExample(const vector<float>& inputs, int target) {
    109   //CHECK_GE(target, 0);
    110   //CHECK_LT(target, num_classes_);
    111   float target_class_score = Score(inputs, parameters_[target]);
    112   //  VLOG(1) << "target class " << target << " score " << target_class_score;
    113   int other_class = PickAClassExcept(target);
    114   float other_class_score = Score(inputs, parameters_[other_class]);
    115   //  VLOG(1) << "other class " << other_class << " score " << other_class_score;
    116   float loss = 1.0 - target_class_score + other_class_score;
    117   if (loss > 0.0) {
    118     // Compute the learning rate according to PA-I.
    119     float twice_norm_square = L2NormSquare(inputs) * 2.0;
    120     if (twice_norm_square == 0.0) {
    121       twice_norm_square = kEpsilon;
    122     }
    123     float rate = loss / twice_norm_square;
    124     if (rate > aggressiveness_) {
    125       rate = aggressiveness_;
    126     }
    127     //    VLOG(1) << "loss = " << loss << " rate = " << rate;
    128     // Modify the parameter vectors of the correct and wrong classes
    129     for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
    130       // First modify the parameter value of the correct class
    131       parameters_[target][i] += rate * inputs[i];
    132       // Then modify the parameter value of the wrong class
    133       parameters_[other_class][i] -= rate * inputs[i];
    134     }
    135     return loss;
    136   }
    137   return 0.0;
    138 }
    139 
    140 float MulticlassPA::SparseTrainOneExample(
    141     const vector<pair<int, float> >& inputs, int target) {
    142   // CHECK_GE(target, 0);
    143   // CHECK_LT(target, num_classes_);
    144   float target_class_score = SparseScore(inputs, parameters_[target]);
    145   //  VLOG(1) << "target class " << target << " score " << target_class_score;
    146   int other_class = PickAClassExcept(target);
    147   float other_class_score = SparseScore(inputs, parameters_[other_class]);
    148   //  VLOG(1) << "other class " << other_class << " score " << other_class_score;
    149   float loss = 1.0 - target_class_score + other_class_score;
    150   if (loss > 0.0) {
    151     // Compute the learning rate according to PA-I.
    152     float twice_norm_square = SparseL2NormSquare(inputs) * 2.0;
    153     if (twice_norm_square == 0.0) {
    154       twice_norm_square = kEpsilon;
    155     }
    156     float rate = loss / twice_norm_square;
    157     if (rate > aggressiveness_) {
    158       rate = aggressiveness_;
    159     }
    160     //    VLOG(1) << "loss = " << loss << " rate = " << rate;
    161     // Modify the parameter vectors of the correct and wrong classes
    162     for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
    163       // First modify the parameter value of the correct class
    164       parameters_[target][inputs[i].first] += rate * inputs[i].second;
    165       // Then modify the parameter value of the wrong class
    166       parameters_[other_class][inputs[i].first] -= rate * inputs[i].second;
    167     }
    168     return loss;
    169   }
    170   return 0.0;
    171 }
    172 
    173 float MulticlassPA::Train(const vector<pair<vector<float>, int> >& data,
    174                           int num_iterations) {
    175   int num_examples = data.size();
    176   float total_loss = 0.0;
    177   for (int t = 0; t < num_iterations; ++t) {
    178     int index = PickAnExample(num_examples);
    179     float loss_t = TrainOneExample(data[index].first, data[index].second);
    180     total_loss += loss_t;
    181   }
    182   return total_loss / static_cast<float>(num_iterations);
    183 }
    184 
    185 float MulticlassPA::SparseTrain(
    186     const vector<pair<vector<pair<int, float> >, int> >& data,
    187     int num_iterations) {
    188   int num_examples = data.size();
    189   float total_loss = 0.0;
    190   for (int t = 0; t < num_iterations; ++t) {
    191     int index = PickAnExample(num_examples);
    192     float loss_t = SparseTrainOneExample(data[index].first, data[index].second);
    193     total_loss += loss_t;
    194   }
    195   return total_loss / static_cast<float>(num_iterations);
    196 }
    197 
    198 int MulticlassPA::GetClass(const vector<float>& inputs) {
    199   int best_class = -1;
    200   float best_score = -10000.0;
    201   // float best_score = -MathLimits<float>::kMax;
    202   for (int i = 0; i < num_classes_; ++i) {
    203     float score_i = Score(inputs, parameters_[i]);
    204     if (score_i > best_score) {
    205       best_score = score_i;
    206       best_class = i;
    207     }
    208   }
    209   return best_class;
    210 }
    211 
    212 int MulticlassPA::SparseGetClass(const vector<pair<int, float> >& inputs) {
    213   int best_class = -1;
    214   float best_score = -10000.0;
    215   //float best_score = -MathLimits<float>::kMax;
    216   for (int i = 0; i < num_classes_; ++i) {
    217     float score_i = SparseScore(inputs, parameters_[i]);
    218     if (score_i > best_score) {
    219       best_score = score_i;
    220       best_class = i;
    221     }
    222   }
    223   return best_class;
    224 }
    225 
    226 float MulticlassPA::Test(const vector<pair<vector<float>, int> >& data) {
    227   int num_examples = data.size();
    228   float total_error = 0.0;
    229   for (int t = 0; t < num_examples; ++t) {
    230     int best_class = GetClass(data[t].first);
    231     if (best_class != data[t].second) {
    232       ++total_error;
    233     }
    234   }
    235   return total_error / num_examples;
    236 }
    237 
    238 float MulticlassPA::SparseTest(
    239     const vector<pair<vector<pair<int, float> >, int> >& data) {
    240   int num_examples = data.size();
    241   float total_error = 0.0;
    242   for (int t = 0; t < num_examples; ++t) {
    243     int best_class = SparseGetClass(data[t].first);
    244     if (best_class != data[t].second) {
    245       ++total_error;
    246     }
    247   }
    248   return total_error / num_examples;
    249 }
    250 }  // namespace learningfw
    251