/*
 * Copyright (C) 2012 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//
// This file contains the MulticlassPA class which implements a simple
// linear multi-class classifier based on the multi-prototype version of
// passive aggressive.

#include "native/multiclass_pa.h"

#include <cstdlib>  // for rand() and RAND_MAX

using std::vector;
using std::pair;

namespace learningfw {

// Returns a uniformly distributed float in [0.0, 1.0].
float RandFloat() {
  return static_cast<float>(rand()) / RAND_MAX;
}

MulticlassPA::MulticlassPA(int num_classes,
                           int num_dimensions,
                           float aggressiveness)
    : num_classes_(num_classes),
      num_dimensions_(num_dimensions),
      aggressiveness_(aggressiveness) {
  InitializeParameters();
}

MulticlassPA::~MulticlassPA() {
}

void MulticlassPA::InitializeParameters() {
  parameters_.resize(num_classes_);
  for (int i = 0; i < num_classes_; ++i) {
    parameters_[i].resize(num_dimensions_);
    for (int j = 0; j < num_dimensions_; ++j) {
      parameters_[i][j] = 0.0;
    }
  }
}

int MulticlassPA::PickAClassExcept(int target) {
  int picked;
  do {
    picked = static_cast<int>(RandFloat() * num_classes_);
    // picked = static_cast<int>(random_.RandFloat() * num_classes_);
  } while (target == picked);
  return picked;
}

int MulticlassPA::PickAnExample(int num_examples) {
  return static_cast<int>(RandFloat() * num_examples);
}

float MulticlassPA::Score(const vector<float>& inputs,
                          const vector<float>& parameters) const {
  // CHECK_EQ(inputs.size(), parameters.size());
  float result = 0.0;
  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
    result += inputs[i] * parameters[i];
  }
  return result;
}

float MulticlassPA::SparseScore(const vector<pair<int, float> >& inputs,
                                const vector<float>& parameters) const {
  float result = 0.0;
  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
    // DCHECK_GE(inputs[i].first, 0);
    // DCHECK_LT(inputs[i].first, parameters.size());
    result += inputs[i].second * parameters[inputs[i].first];
  }
  return result;
}

float MulticlassPA::L2NormSquare(const vector<float>& inputs) const {
  float norm = 0;
  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
    norm += inputs[i] * inputs[i];
  }
  return norm;
}

float MulticlassPA::SparseL2NormSquare(
    const vector<pair<int, float> >& inputs) const {
  float norm = 0;
  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
    norm += inputs[i].second * inputs[i].second;
  }
  return norm;
}

float MulticlassPA::TrainOneExample(const vector<float>& inputs, int target) {
  // CHECK_GE(target, 0);
  // CHECK_LT(target, num_classes_);
  float target_class_score = Score(inputs, parameters_[target]);
  // VLOG(1) << "target class " << target << " score " << target_class_score;
  int other_class = PickAClassExcept(target);
  float other_class_score = Score(inputs, parameters_[other_class]);
  // VLOG(1) << "other class " << other_class << " score " << other_class_score;
  float loss = 1.0 - target_class_score + other_class_score;
  if (loss > 0.0) {
    // Compute the learning rate according to PA-I:
    // rate = min(aggressiveness, loss / (2 * ||x||^2)).
    float twice_norm_square = L2NormSquare(inputs) * 2.0;
    if (twice_norm_square == 0.0) {
      twice_norm_square = kEpsilon;
    }
    float rate = loss / twice_norm_square;
    if (rate > aggressiveness_) {
      rate = aggressiveness_;
    }
    // VLOG(1) << "loss = " << loss << " rate = " << rate;
    // Modify the parameter vectors of the correct and wrong classes.
    for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
      // First modify the parameter value of the correct class.
      parameters_[target][i] += rate * inputs[i];
      // Then modify the parameter value of the wrong class.
      parameters_[other_class][i] -= rate * inputs[i];
    }
    return loss;
  }
  return 0.0;
}

float MulticlassPA::SparseTrainOneExample(
    const vector<pair<int, float> >& inputs, int target) {
  // CHECK_GE(target, 0);
  // CHECK_LT(target, num_classes_);
  float target_class_score = SparseScore(inputs, parameters_[target]);
  // VLOG(1) << "target class " << target << " score " << target_class_score;
  int other_class = PickAClassExcept(target);
  float other_class_score = SparseScore(inputs, parameters_[other_class]);
  // VLOG(1) << "other class " << other_class << " score " << other_class_score;
  float loss = 1.0 - target_class_score + other_class_score;
  if (loss > 0.0) {
    // Compute the learning rate according to PA-I:
    // rate = min(aggressiveness, loss / (2 * ||x||^2)).
    float twice_norm_square = SparseL2NormSquare(inputs) * 2.0;
    if (twice_norm_square == 0.0) {
      twice_norm_square = kEpsilon;
    }
    float rate = loss / twice_norm_square;
    if (rate > aggressiveness_) {
      rate = aggressiveness_;
    }
    // VLOG(1) << "loss = " << loss << " rate = " << rate;
    // Modify the parameter vectors of the correct and wrong classes.
    for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
      // First modify the parameter value of the correct class.
      parameters_[target][inputs[i].first] += rate * inputs[i].second;
      // Then modify the parameter value of the wrong class.
      parameters_[other_class][inputs[i].first] -= rate * inputs[i].second;
    }
    return loss;
  }
  return 0.0;
}

float MulticlassPA::Train(const vector<pair<vector<float>, int> >& data,
                          int num_iterations) {
  int num_examples = data.size();
  float total_loss = 0.0;
  for (int t = 0; t < num_iterations; ++t) {
    int index = PickAnExample(num_examples);
    float loss_t = TrainOneExample(data[index].first, data[index].second);
    total_loss += loss_t;
  }
  return total_loss / static_cast<float>(num_iterations);
}

float MulticlassPA::SparseTrain(
    const vector<pair<vector<pair<int, float> >, int> >& data,
    int num_iterations) {
  int num_examples = data.size();
  float total_loss = 0.0;
  for (int t = 0; t < num_iterations; ++t) {
    int index = PickAnExample(num_examples);
    float loss_t = SparseTrainOneExample(data[index].first, data[index].second);
    total_loss += loss_t;
  }
  return total_loss / static_cast<float>(num_iterations);
}

int MulticlassPA::GetClass(const vector<float>& inputs) {
  int best_class = -1;
  float best_score = -10000.0;
  // float best_score = -MathLimits<float>::kMax;
  for (int i = 0; i < num_classes_; ++i) {
    float score_i = Score(inputs, parameters_[i]);
    if (score_i > best_score) {
      best_score = score_i;
      best_class = i;
    }
  }
  return best_class;
}

int MulticlassPA::SparseGetClass(const vector<pair<int, float> >& inputs) {
  int best_class = -1;
  float best_score = -10000.0;
  // float best_score = -MathLimits<float>::kMax;
  for (int i = 0; i < num_classes_; ++i) {
    float score_i = SparseScore(inputs, parameters_[i]);
    if (score_i > best_score) {
      best_score = score_i;
      best_class = i;
    }
  }
  return best_class;
}

float MulticlassPA::Test(const vector<pair<vector<float>, int> >& data) {
  int num_examples = data.size();
  float total_error = 0.0;
  for (int t = 0; t < num_examples; ++t) {
    int best_class = GetClass(data[t].first);
    if (best_class != data[t].second) {
      ++total_error;
    }
  }
  return total_error / num_examples;
}

float MulticlassPA::SparseTest(
    const vector<pair<vector<pair<int, float> >, int> >& data) {
  int num_examples = data.size();
  float total_error = 0.0;
  for (int t = 0; t < num_examples; ++t) {
    int best_class = SparseGetClass(data[t].first);
    if (best_class != data[t].second) {
      ++total_error;
    }
  }
  return total_error / num_examples;
}
}  // namespace learningfw
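
// ---------------------------------------------------------------------------
// Illustrative usage sketch (editorial addition, not part of the original
// file). It assumes only what the definitions above establish: the
// learningfw::MulticlassPA constructor taking (num_classes, num_dimensions,
// aggressiveness), dense Train()/Test()/GetClass(), and sparse features
// encoded as vector<pair<int, float> > of (dimension index, value) pairs.
//
//   #include "native/multiclass_pa.h"
//   #include <utility>
//   #include <vector>
//
//   int main() {
//     // 3 classes, 4-dimensional dense features, aggressiveness C = 0.1
//     // (the cap applied to the PA-I learning rate).
//     learningfw::MulticlassPA classifier(3, 4, 0.1f);
//
//     // Each training example is a (feature vector, class label) pair.
//     std::vector<std::pair<std::vector<float>, int> > data;
//     for (int c = 0; c < 3; ++c) {
//       std::vector<float> x(4, 0.0f);
//       x[c] = 1.0f;  // one-hot feature for class c
//       data.push_back(std::make_pair(x, c));
//     }
//
//     float avg_loss = classifier.Train(data, 1000);  // mean PA loss
//     float error_rate = classifier.Test(data);       // fraction misclassified
//     int predicted = classifier.GetClass(data[0].first);
//     return 0;
//   }
// ---------------------------------------------------------------------------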