1 /* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 // 18 // This file contains the MulticlassPA class which implements a simple 19 // linear multi-class classifier based on the multi-prototype version of 20 // passive aggressive. 21 22 #include "native/multiclass_pa.h" 23 24 #include <stdlib.h> 25 26 using std::vector; 27 using std::pair; 28 29 namespace learningfw { 30 31 float RandFloat() { 32 return static_cast<float>(rand()) / RAND_MAX; 33 } 34 35 MulticlassPA::MulticlassPA(int num_classes, 36 int num_dimensions, 37 float aggressiveness) 38 : num_classes_(num_classes), 39 num_dimensions_(num_dimensions), 40 aggressiveness_(aggressiveness) { 41 InitializeParameters(); 42 } 43 44 MulticlassPA::~MulticlassPA() { 45 } 46 47 void MulticlassPA::InitializeParameters() { 48 parameters_.resize(num_classes_); 49 for (int i = 0; i < num_classes_; ++i) { 50 parameters_[i].resize(num_dimensions_); 51 for (int j = 0; j < num_dimensions_; ++j) { 52 parameters_[i][j] = 0.0; 53 } 54 } 55 } 56 57 int MulticlassPA::PickAClassExcept(int target) { 58 int picked; 59 do { 60 picked = static_cast<int>(RandFloat() * num_classes_); 61 // picked = static_cast<int>(random_.RandFloat() * num_classes_); 62 } while (target == picked); 63 return picked; 64 } 65 66 int MulticlassPA::PickAnExample(int num_examples) { 67 return static_cast<int>(RandFloat() * num_examples); 68 } 69 70 float MulticlassPA::Score(const vector<float>& inputs, 71 const vector<float>& parameters) const { 72 // CHECK_EQ(inputs.size(), parameters.size()); 73 float result = 0.0; 74 for (int i = 0; i < static_cast<int>(inputs.size()); ++i) { 75 result += inputs[i] * parameters[i]; 76 } 77 return result; 78 } 79 80 float MulticlassPA::SparseScore(const vector<pair<int, float> >& inputs, 81 const vector<float>& parameters) const { 82 float result = 0.0; 83 for (int i = 0; i < static_cast<int>(inputs.size()); ++i) { 84 //DCHECK_GE(inputs[i].first, 0); 85 //DCHECK_LT(inputs[i].first, parameters.size()); 86 result += inputs[i].second * parameters[inputs[i].first]; 87 } 88 return result; 89 } 90 91 float MulticlassPA::L2NormSquare(const vector<float>& inputs) const { 92 float norm = 0; 93 for (int i = 0; i < static_cast<int>(inputs.size()); ++i) { 94 norm += inputs[i] * inputs[i]; 95 } 96 return norm; 97 } 98 99 float MulticlassPA::SparseL2NormSquare( 100 const vector<pair<int, float> >& inputs) const { 101 float norm = 0; 102 for (int i = 0; i < static_cast<int>(inputs.size()); ++i) { 103 norm += inputs[i].second * inputs[i].second; 104 } 105 return norm; 106 } 107 108 float MulticlassPA::TrainOneExample(const vector<float>& inputs, int target) { 109 //CHECK_GE(target, 0); 110 //CHECK_LT(target, num_classes_); 111 float target_class_score = Score(inputs, parameters_[target]); 112 // VLOG(1) << "target class " << target << " score " << target_class_score; 113 int other_class = PickAClassExcept(target); 114 float other_class_score = Score(inputs, parameters_[other_class]); 115 // VLOG(1) << "other class " << other_class << " score " << other_class_score; 116 float loss = 1.0 - target_class_score + other_class_score; 117 if (loss > 0.0) { 118 // Compute the learning rate according to PA-I. 119 float twice_norm_square = L2NormSquare(inputs) * 2.0; 120 if (twice_norm_square == 0.0) { 121 twice_norm_square = kEpsilon; 122 } 123 float rate = loss / twice_norm_square; 124 if (rate > aggressiveness_) { 125 rate = aggressiveness_; 126 } 127 // VLOG(1) << "loss = " << loss << " rate = " << rate; 128 // Modify the parameter vectors of the correct and wrong classes 129 for (int i = 0; i < static_cast<int>(inputs.size()); ++i) { 130 // First modify the parameter value of the correct class 131 parameters_[target][i] += rate * inputs[i]; 132 // Then modify the parameter value of the wrong class 133 parameters_[other_class][i] -= rate * inputs[i]; 134 } 135 return loss; 136 } 137 return 0.0; 138 } 139 140 float MulticlassPA::SparseTrainOneExample( 141 const vector<pair<int, float> >& inputs, int target) { 142 // CHECK_GE(target, 0); 143 // CHECK_LT(target, num_classes_); 144 float target_class_score = SparseScore(inputs, parameters_[target]); 145 // VLOG(1) << "target class " << target << " score " << target_class_score; 146 int other_class = PickAClassExcept(target); 147 float other_class_score = SparseScore(inputs, parameters_[other_class]); 148 // VLOG(1) << "other class " << other_class << " score " << other_class_score; 149 float loss = 1.0 - target_class_score + other_class_score; 150 if (loss > 0.0) { 151 // Compute the learning rate according to PA-I. 152 float twice_norm_square = SparseL2NormSquare(inputs) * 2.0; 153 if (twice_norm_square == 0.0) { 154 twice_norm_square = kEpsilon; 155 } 156 float rate = loss / twice_norm_square; 157 if (rate > aggressiveness_) { 158 rate = aggressiveness_; 159 } 160 // VLOG(1) << "loss = " << loss << " rate = " << rate; 161 // Modify the parameter vectors of the correct and wrong classes 162 for (int i = 0; i < static_cast<int>(inputs.size()); ++i) { 163 // First modify the parameter value of the correct class 164 parameters_[target][inputs[i].first] += rate * inputs[i].second; 165 // Then modify the parameter value of the wrong class 166 parameters_[other_class][inputs[i].first] -= rate * inputs[i].second; 167 } 168 return loss; 169 } 170 return 0.0; 171 } 172 173 float MulticlassPA::Train(const vector<pair<vector<float>, int> >& data, 174 int num_iterations) { 175 int num_examples = data.size(); 176 float total_loss = 0.0; 177 for (int t = 0; t < num_iterations; ++t) { 178 int index = PickAnExample(num_examples); 179 float loss_t = TrainOneExample(data[index].first, data[index].second); 180 total_loss += loss_t; 181 } 182 return total_loss / static_cast<float>(num_iterations); 183 } 184 185 float MulticlassPA::SparseTrain( 186 const vector<pair<vector<pair<int, float> >, int> >& data, 187 int num_iterations) { 188 int num_examples = data.size(); 189 float total_loss = 0.0; 190 for (int t = 0; t < num_iterations; ++t) { 191 int index = PickAnExample(num_examples); 192 float loss_t = SparseTrainOneExample(data[index].first, data[index].second); 193 total_loss += loss_t; 194 } 195 return total_loss / static_cast<float>(num_iterations); 196 } 197 198 int MulticlassPA::GetClass(const vector<float>& inputs) { 199 int best_class = -1; 200 float best_score = -10000.0; 201 // float best_score = -MathLimits<float>::kMax; 202 for (int i = 0; i < num_classes_; ++i) { 203 float score_i = Score(inputs, parameters_[i]); 204 if (score_i > best_score) { 205 best_score = score_i; 206 best_class = i; 207 } 208 } 209 return best_class; 210 } 211 212 int MulticlassPA::SparseGetClass(const vector<pair<int, float> >& inputs) { 213 int best_class = -1; 214 float best_score = -10000.0; 215 //float best_score = -MathLimits<float>::kMax; 216 for (int i = 0; i < num_classes_; ++i) { 217 float score_i = SparseScore(inputs, parameters_[i]); 218 if (score_i > best_score) { 219 best_score = score_i; 220 best_class = i; 221 } 222 } 223 return best_class; 224 } 225 226 float MulticlassPA::Test(const vector<pair<vector<float>, int> >& data) { 227 int num_examples = data.size(); 228 float total_error = 0.0; 229 for (int t = 0; t < num_examples; ++t) { 230 int best_class = GetClass(data[t].first); 231 if (best_class != data[t].second) { 232 ++total_error; 233 } 234 } 235 return total_error / num_examples; 236 } 237 238 float MulticlassPA::SparseTest( 239 const vector<pair<vector<pair<int, float> >, int> >& data) { 240 int num_examples = data.size(); 241 float total_error = 0.0; 242 for (int t = 0; t < num_examples; ++t) { 243 int best_class = SparseGetClass(data[t].first); 244 if (best_class != data[t].second) { 245 ++total_error; 246 } 247 } 248 return total_error / num_examples; 249 } 250 } // namespace learningfw 251