Home | History | Annotate | Download | only in native
      1 /*
      2  * Copyright (C) 2012 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
// Implements the learning-rate adaptation schedules common to most stochastic algorithms.
     18 
     19 #ifndef LEARNING_STOCHASTIC_LINEAR_LEARNING_RATE_CONTROLLER_INL_H_
     20 #define LEARNING_STOCHASTIC_LINEAR_LEARNING_RATE_CONTROLLER_INL_H_
     21 
     22 #include <cmath>
     23 #include "common_defs.h"
     24 
     25 namespace learning_stochastic_linear {
     26 
     27 class LearningRateController {
     28  public:
     29   LearningRateController() {
     30     iteration_num_ = 1;
     31     lambda_ = 1.0;
     32     mini_batch_size_ = 1;
     33     mini_batch_counter_ = 1;
     34     sample_num_ = 1;
     35     mode_ = INV_LINEAR;
     36     is_first_sample_ = true;
     37   }
     38   ~LearningRateController() {}
     39   // Getters and Setters for learning rate parameter lambda_
     40   double GetLambda() const {
     41     return lambda_;
     42   }
     43   void SetLambda(double lambda) {
     44     lambda_ = lambda;
     45   }
     46   // Operations on current iteration number
     47   void SetIterationNumber(uint64 num) {
     48     iteration_num_ = num;
     49   }
     50   void IncrementIteration() {
     51     ++iteration_num_;
     52   }
     53   uint64 GetIterationNumber() const {
     54     return iteration_num_;
     55   }
     56   // Mini batch operations
     57   uint64 GetMiniBatchSize() const {
     58     return mini_batch_size_;
     59   }
     60   void SetMiniBatchSize(uint64 size) {
     61     //CHECK_GT(size, 0);
     62     mini_batch_size_ = size;
     63   }
     64   void IncrementSample() {
     65     // If this is the first sample we've already counted it to prevent NaNs
     66     // in the learning rate computation
     67     if (is_first_sample_) {
     68       is_first_sample_ = false;
     69       return;
     70     }
     71     ++sample_num_;
     72     if (1 == mini_batch_size_) {
     73       IncrementIteration();
     74       mini_batch_counter_ = 0;
     75     } else {
     76       ++mini_batch_counter_;
     77       if ((mini_batch_counter_ % mini_batch_size_ == 0)) {
     78         IncrementIteration();
     79         mini_batch_counter_ = 0;
     80       }
     81     }
     82   }
     83   uint64 GetMiniBatchCounter() const {
     84     return mini_batch_counter_;
     85   }
     86   // Getters and setters for adaptation mode
     87   AdaptationMode GetAdaptationMode() const {
     88     return mode_;
     89   }
     90   void SetAdaptationMode(AdaptationMode m) {
     91     mode_ = m;
     92   }
     93   double GetLearningRate() const {
     94     if (mode_ == CONST) {
     95       return (1.0 / (lambda_ * mini_batch_size_));
     96     } else if (mode_ == INV_LINEAR) {
     97       return (1.0 / (lambda_ * iteration_num_ * mini_batch_size_));
     98     } else if (mode_ == INV_QUADRATIC) {
     99       return (1.0 / (lambda_ *
    100                      mini_batch_size_ *
    101                      (static_cast<double>(iteration_num_) * iteration_num_)));
    102     } else if (mode_ == INV_SQRT) {
    103       return (1.0 / (lambda_ *
    104                      mini_batch_size_ *
    105                      sqrt((double)iteration_num_)));
    106     }
    107     return 0;
    108   }
    109   void CopyFrom(const LearningRateController &other) {
    110     iteration_num_ = other.iteration_num_;
    111     sample_num_ = other.sample_num_;
    112     mini_batch_size_ = other.mini_batch_size_;
    113     mini_batch_counter_ = other.mini_batch_counter_;
    114     mode_ = other.mode_;
    115     is_first_sample_ = other.is_first_sample_;
    116   }
    117  private:
    118   uint64 iteration_num_;
    119   uint64 sample_num_;
    120   uint64 mini_batch_size_;
    121   uint64 mini_batch_counter_;
    122   double lambda_;
    123   AdaptationMode mode_;
    124   bool is_first_sample_;
    125 };
    126 }  // namespace learning_stochastic_linear
    127 #endif  // LEARNING_STOCHASTIC_LINEAR_LEARNING_RATE_CONTROLLER_INL_H_
    128