Home | History | Annotate | Download | only in native
      1 /*
      2  * Copyright (C) 2012 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 // Purpose: A container for sparse weight vectors
// Maintains the sparse vector as a list of (name, value) pairs along with
// a normalizer_. All operations assume that (name, value / normalizer_) is
// the true value in question.
     21 
     22 #ifndef LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_
     23 #define LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_
     24 
     25 #include <math.h>
     26 
     27 #include <iosfwd>
     28 #include <sstream>
     29 #include <string>
     30 #include <unordered_map>
     31 
     32 #include "common_defs.h"
     33 
     34 namespace learning_stochastic_linear {
     35 
     36 template<class Key = std::string, class Hash = std::unordered_map<Key, double> >
     37 class SparseWeightVector {
     38  public:
     39   typedef Hash Wmap;
     40   typedef typename Wmap::iterator Witer;
     41   typedef typename Wmap::const_iterator Witer_const;
     42   SparseWeightVector() {
     43     normalizer_ = 1.0;
     44   }
     45   ~SparseWeightVector() {}
     46   explicit SparseWeightVector(const SparseWeightVector<Key, Hash> &other) {
     47     CopyFrom(other);
     48   }
     49   void operator=(const SparseWeightVector<Key, Hash> &other) {
     50     CopyFrom(other);
     51   }
     52   void CopyFrom(const SparseWeightVector<Key, Hash> &other) {
     53     w_ = other.w_;
     54     wmin_ = other.wmin_;
     55     wmax_ = other.wmax_;
     56     normalizer_ = other.normalizer_;
     57   }
     58 
     59   // This function implements checks to prevent unbounded vectors. It returns
     60   // true if the checks succeed and false otherwise. A vector is deemed invalid
     61   // if any of these conditions are met:
     62   // 1. it has no values.
     63   // 2. its normalizer is nan or inf or close to zero.
     64   // 3. any of its values are nan or inf.
     65   // 4. its L0 norm is close to zero.
     66   bool IsValid() const;
     67 
     68   // Normalizer getters and setters.
     69   double GetNormalizer() const {
     70     return normalizer_;
     71   }
     72   void SetNormalizer(const double norm) {
     73     normalizer_ = norm;
     74   }
     75   void NormalizerMultUpdate(const double mul) {
     76     normalizer_ = normalizer_ * mul;
     77   }
     78   void NormalizerAddUpdate(const double add) {
     79     normalizer_ += add;
     80   }
     81 
     82   // Divides all the values by the normalizer, then it resets it to 1.0
     83   void ResetNormalizer();
     84 
     85   // Bound getters and setters.
     86   // True if there is a bound with val containing the bound. false otherwise.
     87   bool GetElementMinBound(const Key &fname, double *val) const {
     88     return GetValue(wmin_, fname, val);
     89   }
     90   bool GetElementMaxBound(const Key &fname, double *val) const {
     91     return GetValue(wmax_, fname, val);
     92   }
     93   void SetElementMinBound(const Key &fname, const double bound) {
     94     wmin_[fname] = bound;
     95   }
     96   void SetElementMaxBound(const Key &fname, const double bound) {
     97     wmax_[fname] = bound;
     98   }
     99   // Element getters and setters.
    100   double GetElement(const Key &fname) const {
    101     double val = 0;
    102     GetValue(w_, fname, &val);
    103     return val;
    104   }
    105   void SetElement(const Key &fname, const double val) {
    106     //DCHECK(!isnan(val));
    107     w_[fname] = val;
    108   }
    109   void AddUpdateElement(const Key &fname, const double val) {
    110     w_[fname] += val;
    111   }
    112   void MultUpdateElement(const Key &fname, const double val) {
    113     w_[fname] *= val;
    114   }
    115   // Load another weight vectors. Will overwrite the current vector.
    116   void LoadWeightVector(const SparseWeightVector<Key, Hash> &vec) {
    117     w_.clear();
    118     w_.insert(vec.w_.begin(), vec.w_.end());
    119     wmax_.insert(vec.wmax_.begin(), vec.wmax_.end());
    120     wmin_.insert(vec.wmin_.begin(), vec.wmin_.end());
    121     normalizer_ = vec.normalizer_;
    122   }
    123   void Clear() {
    124     w_.clear();
    125     wmax_.clear();
    126     wmin_.clear();
    127   }
    128   const Wmap& GetMap() const {
    129     return w_;
    130   }
    131   // Vector Operations.
    132   void AdditiveWeightUpdate(const double multiplier,
    133                             const SparseWeightVector<Key, Hash> &w1,
    134                             const double additive_const);
    135   void AdditiveSquaredWeightUpdate(const double multiplier,
    136                                    const SparseWeightVector<Key, Hash> &w1,
    137                                    const double additive_const);
    138   void AdditiveInvSqrtWeightUpdate(const double multiplier,
    139                                    const SparseWeightVector<Key, Hash> &w1,
    140                                    const double additive_const);
    141   void MultWeightUpdate(const SparseWeightVector<Key, Hash> &w1);
    142   double DotProduct(const SparseWeightVector<Key, Hash> &s) const;
    143   // L-x norm. eg. L1, L2.
    144   double LxNorm(const double x) const;
    145   double L2Norm() const;
    146   double L1Norm() const;
    147   double L0Norm(const double epsilon) const;
    148   // Bound preserving updates.
    149   void AdditiveWeightUpdateBounded(const double multiplier,
    150                                    const SparseWeightVector<Key, Hash> &w1,
    151                                    const double additive_const);
    152   void MultWeightUpdateBounded(const SparseWeightVector<Key, Hash> &w1);
    153   void ReprojectToBounds();
    154   void ReprojectL0(const double l0_norm);
    155   void ReprojectL1(const double l1_norm);
    156   void ReprojectL2(const double l2_norm);
    157   // Reproject using the given norm.
    158   // Will also rescale regularizer_ if it gets too small/large.
    159   int32 Reproject(const double norm, const RegularizationType r);
    160   // Convert this vector to a string, simply for debugging.
    161   std::string DebugString() const {
    162     std::stringstream stream;
    163     stream << *this;
    164     return stream.str();
    165   }
    166  private:
    167   // The weight map.
    168   Wmap w_;
    169   // Constraint bounds.
    170   Wmap wmin_;
    171   Wmap wmax_;
    172   // Normalizing constant in magnitude measurement.
    173   double normalizer_;
    174   // This function is necessary since by default unordered_map inserts an
    175   // element if it does not find the key through [] operator. It implements a
    176   // lookup without the space overhead of an add.
    177   bool GetValue(const Wmap &w1, const Key &fname, double *val) const {
    178     Witer_const iter = w1.find(fname);
    179     if (iter != w1.end()) {
    180       (*val) = iter->second;
    181       return true;
    182     } else {
    183       (*val) = 0;
    184       return false;
    185     }
    186   }
    187 };
    188 
    189 // Outputs a SparseWeightVector, for debugging.
    190 template <class Key, class Hash>
    191 std::ostream& operator<<(std::ostream &stream,
    192                     const SparseWeightVector<Key, Hash> &vector) {
    193   typename SparseWeightVector<Key, Hash>::Wmap w_map = vector.GetMap();
    194   stream << "[[ ";
    195   for (typename SparseWeightVector<Key, Hash>::Witer_const iter = w_map.begin();
    196        iter != w_map.end();
    197        ++iter) {
    198     stream << "<" << iter->first << ", " << iter->second << "> ";
    199   }
    200   return stream << " ]]";
    201 };
    202 
    203 }  // namespace learning_stochastic_linear
    204 #endif  // LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_
    205