Home | History | Annotate | Download | only in native
      1 /*
      2  * Copyright (C) 2012 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
// Purpose: A container for sparse weight vectors.
// Maintains the sparse vector as a map of (name, value) pairs along with
// a normalizer_. All operations assume that (name, value / normalizer_) is the
// true value in question.
     21 
     22 #ifndef LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_
     23 #define LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_
     24 
     25 #include <hash_map>
     26 #include <iosfwd>
     27 #include <math.h>
     28 #include <sstream>
     29 #include <string>
     30 
     31 #include "common_defs.h"
     32 
     33 namespace learning_stochastic_linear {
     34 
     35 template<class Key = std::string, class Hash = std::hash_map<Key, double> >
     36 class SparseWeightVector {
     37  public:
     38   typedef Hash Wmap;
     39   typedef typename Wmap::iterator Witer;
     40   typedef typename Wmap::const_iterator Witer_const;
     41   SparseWeightVector() {
     42     normalizer_ = 1.0;
     43   }
     44   ~SparseWeightVector() {}
     45   explicit SparseWeightVector(const SparseWeightVector<Key, Hash> &other) {
     46     CopyFrom(other);
     47   }
     48   void operator=(const SparseWeightVector<Key, Hash> &other) {
     49     CopyFrom(other);
     50   }
     51   void CopyFrom(const SparseWeightVector<Key, Hash> &other) {
     52     w_ = other.w_;
     53     wmin_ = other.wmin_;
     54     wmax_ = other.wmax_;
     55     normalizer_ = other.normalizer_;
     56   }
     57 
     58   // This function implements checks to prevent unbounded vectors. It returns
     59   // true if the checks succeed and false otherwise. A vector is deemed invalid
     60   // if any of these conditions are met:
     61   // 1. it has no values.
     62   // 2. its normalizer is nan or inf or close to zero.
     63   // 3. any of its values are nan or inf.
     64   // 4. its L0 norm is close to zero.
     65   bool IsValid() const;
     66 
     67   // Normalizer getters and setters.
     68   double GetNormalizer() const {
     69     return normalizer_;
     70   }
     71   void SetNormalizer(const double norm) {
     72     normalizer_ = norm;
     73   }
     74   void NormalizerMultUpdate(const double mul) {
     75     normalizer_ = normalizer_ * mul;
     76   }
     77   void NormalizerAddUpdate(const double add) {
     78     normalizer_ += add;
     79   }
     80 
     81   // Divides all the values by the normalizer, then it resets it to 1.0
     82   void ResetNormalizer();
     83 
     84   // Bound getters and setters.
     85   // True if there is a bound with val containing the bound. false otherwise.
     86   bool GetElementMinBound(const Key &fname, double *val) const {
     87     return GetValue(wmin_, fname, val);
     88   }
     89   bool GetElementMaxBound(const Key &fname, double *val) const {
     90     return GetValue(wmax_, fname, val);
     91   }
     92   void SetElementMinBound(const Key &fname, const double bound) {
     93     wmin_[fname] = bound;
     94   }
     95   void SetElementMaxBound(const Key &fname, const double bound) {
     96     wmax_[fname] = bound;
     97   }
     98   // Element getters and setters.
     99   double GetElement(const Key &fname) const {
    100     double val = 0;
    101     GetValue(w_, fname, &val);
    102     return val;
    103   }
    104   void SetElement(const Key &fname, const double val) {
    105     //DCHECK(!isnan(val));
    106     w_[fname] = val;
    107   }
    108   void AddUpdateElement(const Key &fname, const double val) {
    109     w_[fname] += val;
    110   }
    111   void MultUpdateElement(const Key &fname, const double val) {
    112     w_[fname] *= val;
    113   }
    114   // Load another weight vectors. Will overwrite the current vector.
    115   void LoadWeightVector(const SparseWeightVector<Key, Hash> &vec) {
    116     w_.clear();
    117     w_.insert(vec.w_.begin(), vec.w_.end());
    118     wmax_.insert(vec.wmax_.begin(), vec.wmax_.end());
    119     wmin_.insert(vec.wmin_.begin(), vec.wmin_.end());
    120     normalizer_ = vec.normalizer_;
    121   }
    122   void Clear() {
    123     w_.clear();
    124     wmax_.clear();
    125     wmin_.clear();
    126   }
    127   const Wmap& GetMap() const {
    128     return w_;
    129   }
    130   // Vector Operations.
    131   void AdditiveWeightUpdate(const double multiplier,
    132                             const SparseWeightVector<Key, Hash> &w1,
    133                             const double additive_const);
    134   void AdditiveSquaredWeightUpdate(const double multiplier,
    135                                    const SparseWeightVector<Key, Hash> &w1,
    136                                    const double additive_const);
    137   void AdditiveInvSqrtWeightUpdate(const double multiplier,
    138                                    const SparseWeightVector<Key, Hash> &w1,
    139                                    const double additive_const);
    140   void MultWeightUpdate(const SparseWeightVector<Key, Hash> &w1);
    141   double DotProduct(const SparseWeightVector<Key, Hash> &s) const;
    142   // L-x norm. eg. L1, L2.
    143   double LxNorm(const double x) const;
    144   double L2Norm() const;
    145   double L1Norm() const;
    146   double L0Norm(const double epsilon) const;
    147   // Bound preserving updates.
    148   void AdditiveWeightUpdateBounded(const double multiplier,
    149                                    const SparseWeightVector<Key, Hash> &w1,
    150                                    const double additive_const);
    151   void MultWeightUpdateBounded(const SparseWeightVector<Key, Hash> &w1);
    152   void ReprojectToBounds();
    153   void ReprojectL0(const double l0_norm);
    154   void ReprojectL1(const double l1_norm);
    155   void ReprojectL2(const double l2_norm);
    156   // Reproject using the given norm.
    157   // Will also rescale regularizer_ if it gets too small/large.
    158   int32 Reproject(const double norm, const RegularizationType r);
    159   // Convert this vector to a string, simply for debugging.
    160   std::string DebugString() const {
    161     std::stringstream stream;
    162     stream << *this;
    163     return stream.str();
    164   }
    165  private:
    166   // The weight map.
    167   Wmap w_;
    168   // Constraint bounds.
    169   Wmap wmin_;
    170   Wmap wmax_;
    171   // Normalizing constant in magnitude measurement.
    172   double normalizer_;
    173   // This function in necessary since by default hash_map inserts an element
    174   // if it does not find the key through [] operator. It implements a lookup
    175   // without the space overhead of an add.
    176   bool GetValue(const Wmap &w1, const Key &fname, double *val) const {
    177     Witer_const iter = w1.find(fname);
    178     if (iter != w1.end()) {
    179       (*val) = iter->second;
    180       return true;
    181     } else {
    182       (*val) = 0;
    183       return false;
    184     }
    185   }
    186 };
    187 
    188 // Outputs a SparseWeightVector, for debugging.
    189 template <class Key, class Hash>
    190 std::ostream& operator<<(std::ostream &stream,
    191                     const SparseWeightVector<Key, Hash> &vector) {
    192   typename SparseWeightVector<Key, Hash>::Wmap w_map = vector.GetMap();
    193   stream << "[[ ";
    194   for (typename SparseWeightVector<Key, Hash>::Witer_const iter = w_map.begin();
    195        iter != w_map.end();
    196        ++iter) {
    197     stream << "<" << iter->first << ", " << iter->second << "> ";
    198   }
    199   return stream << " ]]";
    200 };
    201 
    202 }  // namespace learning_stochastic_linear
    203 #endif  // LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_
    204