1 /* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 // Purpose: A container for sparse weight vectors 18 // Maintains the sparse vector as a list of (name, value) pairs alongwith 19 // a normalizer_. All operations assume that (name, value/normalizer_) is the 20 // true value in question. 21 22 #ifndef LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_ 23 #define LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_ 24 25 #include <hash_map> 26 #include <iosfwd> 27 #include <math.h> 28 #include <sstream> 29 #include <string> 30 31 #include "common_defs.h" 32 33 namespace learning_stochastic_linear { 34 35 template<class Key = std::string, class Hash = std::hash_map<Key, double> > 36 class SparseWeightVector { 37 public: 38 typedef Hash Wmap; 39 typedef typename Wmap::iterator Witer; 40 typedef typename Wmap::const_iterator Witer_const; 41 SparseWeightVector() { 42 normalizer_ = 1.0; 43 } 44 ~SparseWeightVector() {} 45 explicit SparseWeightVector(const SparseWeightVector<Key, Hash> &other) { 46 CopyFrom(other); 47 } 48 void operator=(const SparseWeightVector<Key, Hash> &other) { 49 CopyFrom(other); 50 } 51 void CopyFrom(const SparseWeightVector<Key, Hash> &other) { 52 w_ = other.w_; 53 wmin_ = other.wmin_; 54 wmax_ = other.wmax_; 55 normalizer_ = other.normalizer_; 56 } 57 58 // This function implements checks to prevent unbounded vectors. It returns 59 // true if the checks succeed and false otherwise. A vector is deemed invalid 60 // if any of these conditions are met: 61 // 1. it has no values. 62 // 2. its normalizer is nan or inf or close to zero. 63 // 3. any of its values are nan or inf. 64 // 4. its L0 norm is close to zero. 65 bool IsValid() const; 66 67 // Normalizer getters and setters. 68 double GetNormalizer() const { 69 return normalizer_; 70 } 71 void SetNormalizer(const double norm) { 72 normalizer_ = norm; 73 } 74 void NormalizerMultUpdate(const double mul) { 75 normalizer_ = normalizer_ * mul; 76 } 77 void NormalizerAddUpdate(const double add) { 78 normalizer_ += add; 79 } 80 81 // Divides all the values by the normalizer, then it resets it to 1.0 82 void ResetNormalizer(); 83 84 // Bound getters and setters. 85 // True if there is a bound with val containing the bound. false otherwise. 86 bool GetElementMinBound(const Key &fname, double *val) const { 87 return GetValue(wmin_, fname, val); 88 } 89 bool GetElementMaxBound(const Key &fname, double *val) const { 90 return GetValue(wmax_, fname, val); 91 } 92 void SetElementMinBound(const Key &fname, const double bound) { 93 wmin_[fname] = bound; 94 } 95 void SetElementMaxBound(const Key &fname, const double bound) { 96 wmax_[fname] = bound; 97 } 98 // Element getters and setters. 99 double GetElement(const Key &fname) const { 100 double val = 0; 101 GetValue(w_, fname, &val); 102 return val; 103 } 104 void SetElement(const Key &fname, const double val) { 105 //DCHECK(!isnan(val)); 106 w_[fname] = val; 107 } 108 void AddUpdateElement(const Key &fname, const double val) { 109 w_[fname] += val; 110 } 111 void MultUpdateElement(const Key &fname, const double val) { 112 w_[fname] *= val; 113 } 114 // Load another weight vectors. Will overwrite the current vector. 115 void LoadWeightVector(const SparseWeightVector<Key, Hash> &vec) { 116 w_.clear(); 117 w_.insert(vec.w_.begin(), vec.w_.end()); 118 wmax_.insert(vec.wmax_.begin(), vec.wmax_.end()); 119 wmin_.insert(vec.wmin_.begin(), vec.wmin_.end()); 120 normalizer_ = vec.normalizer_; 121 } 122 void Clear() { 123 w_.clear(); 124 wmax_.clear(); 125 wmin_.clear(); 126 } 127 const Wmap& GetMap() const { 128 return w_; 129 } 130 // Vector Operations. 131 void AdditiveWeightUpdate(const double multiplier, 132 const SparseWeightVector<Key, Hash> &w1, 133 const double additive_const); 134 void AdditiveSquaredWeightUpdate(const double multiplier, 135 const SparseWeightVector<Key, Hash> &w1, 136 const double additive_const); 137 void AdditiveInvSqrtWeightUpdate(const double multiplier, 138 const SparseWeightVector<Key, Hash> &w1, 139 const double additive_const); 140 void MultWeightUpdate(const SparseWeightVector<Key, Hash> &w1); 141 double DotProduct(const SparseWeightVector<Key, Hash> &s) const; 142 // L-x norm. eg. L1, L2. 143 double LxNorm(const double x) const; 144 double L2Norm() const; 145 double L1Norm() const; 146 double L0Norm(const double epsilon) const; 147 // Bound preserving updates. 148 void AdditiveWeightUpdateBounded(const double multiplier, 149 const SparseWeightVector<Key, Hash> &w1, 150 const double additive_const); 151 void MultWeightUpdateBounded(const SparseWeightVector<Key, Hash> &w1); 152 void ReprojectToBounds(); 153 void ReprojectL0(const double l0_norm); 154 void ReprojectL1(const double l1_norm); 155 void ReprojectL2(const double l2_norm); 156 // Reproject using the given norm. 157 // Will also rescale regularizer_ if it gets too small/large. 158 int32 Reproject(const double norm, const RegularizationType r); 159 // Convert this vector to a string, simply for debugging. 160 std::string DebugString() const { 161 std::stringstream stream; 162 stream << *this; 163 return stream.str(); 164 } 165 private: 166 // The weight map. 167 Wmap w_; 168 // Constraint bounds. 169 Wmap wmin_; 170 Wmap wmax_; 171 // Normalizing constant in magnitude measurement. 172 double normalizer_; 173 // This function in necessary since by default hash_map inserts an element 174 // if it does not find the key through [] operator. It implements a lookup 175 // without the space overhead of an add. 176 bool GetValue(const Wmap &w1, const Key &fname, double *val) const { 177 Witer_const iter = w1.find(fname); 178 if (iter != w1.end()) { 179 (*val) = iter->second; 180 return true; 181 } else { 182 (*val) = 0; 183 return false; 184 } 185 } 186 }; 187 188 // Outputs a SparseWeightVector, for debugging. 189 template <class Key, class Hash> 190 std::ostream& operator<<(std::ostream &stream, 191 const SparseWeightVector<Key, Hash> &vector) { 192 typename SparseWeightVector<Key, Hash>::Wmap w_map = vector.GetMap(); 193 stream << "[[ "; 194 for (typename SparseWeightVector<Key, Hash>::Witer_const iter = w_map.begin(); 195 iter != w_map.end(); 196 ++iter) { 197 stream << "<" << iter->first << ", " << iter->second << "> "; 198 } 199 return stream << " ]]"; 200 }; 201 202 } // namespace learning_stochastic_linear 203 #endif // LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_ 204