1 /* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 // Implements learning rate adaptations common to most stochastic algorithms. 18 19 #ifndef LEARNING_STOCHASTIC_LINEAR_LEARNING_RATE_CONTROLLER_INL_H_ 20 #define LEARNING_STOCHASTIC_LINEAR_LEARNING_RATE_CONTROLLER_INL_H_ 21 22 #include <cmath> 23 #include "common_defs.h" 24 25 namespace learning_stochastic_linear { 26 27 class LearningRateController { 28 public: 29 LearningRateController() { 30 iteration_num_ = 1; 31 lambda_ = 1.0; 32 mini_batch_size_ = 1; 33 mini_batch_counter_ = 1; 34 sample_num_ = 1; 35 mode_ = INV_LINEAR; 36 is_first_sample_ = true; 37 } 38 ~LearningRateController() {} 39 // Getters and Setters for learning rate parameter lambda_ 40 double GetLambda() const { 41 return lambda_; 42 } 43 void SetLambda(double lambda) { 44 lambda_ = lambda; 45 } 46 // Operations on current iteration number 47 void SetIterationNumber(uint64 num) { 48 iteration_num_ = num; 49 } 50 void IncrementIteration() { 51 ++iteration_num_; 52 } 53 uint64 GetIterationNumber() const { 54 return iteration_num_; 55 } 56 // Mini batch operations 57 uint64 GetMiniBatchSize() const { 58 return mini_batch_size_; 59 } 60 void SetMiniBatchSize(uint64 size) { 61 //CHECK_GT(size, 0); 62 mini_batch_size_ = size; 63 } 64 void IncrementSample() { 65 // If this is the first sample we've already counted it to prevent NaNs 66 // in the learning rate computation 67 if (is_first_sample_) { 68 is_first_sample_ = false; 69 return; 70 } 71 ++sample_num_; 72 if (1 == mini_batch_size_) { 73 IncrementIteration(); 74 mini_batch_counter_ = 0; 75 } else { 76 ++mini_batch_counter_; 77 if ((mini_batch_counter_ % mini_batch_size_ == 0)) { 78 IncrementIteration(); 79 mini_batch_counter_ = 0; 80 } 81 } 82 } 83 uint64 GetMiniBatchCounter() const { 84 return mini_batch_counter_; 85 } 86 // Getters and setters for adaptation mode 87 AdaptationMode GetAdaptationMode() const { 88 return mode_; 89 } 90 void SetAdaptationMode(AdaptationMode m) { 91 mode_ = m; 92 } 93 double GetLearningRate() const { 94 if (mode_ == CONST) { 95 return (1.0 / (lambda_ * mini_batch_size_)); 96 } else if (mode_ == INV_LINEAR) { 97 return (1.0 / (lambda_ * iteration_num_ * mini_batch_size_)); 98 } else if (mode_ == INV_QUADRATIC) { 99 return (1.0 / (lambda_ * 100 mini_batch_size_ * 101 (static_cast<double>(iteration_num_) * iteration_num_))); 102 } else if (mode_ == INV_SQRT) { 103 return (1.0 / (lambda_ * 104 mini_batch_size_ * 105 sqrt((double)iteration_num_))); 106 } 107 return 0; 108 } 109 void CopyFrom(const LearningRateController &other) { 110 iteration_num_ = other.iteration_num_; 111 sample_num_ = other.sample_num_; 112 mini_batch_size_ = other.mini_batch_size_; 113 mini_batch_counter_ = other.mini_batch_counter_; 114 mode_ = other.mode_; 115 is_first_sample_ = other.is_first_sample_; 116 } 117 private: 118 uint64 iteration_num_; 119 uint64 sample_num_; 120 uint64 mini_batch_size_; 121 uint64 mini_batch_counter_; 122 double lambda_; 123 AdaptationMode mode_; 124 bool is_first_sample_; 125 }; 126 } // namespace learning_stochastic_linear 127 #endif // LEARNING_STOCHASTIC_LINEAR_LEARNING_RATE_CONTROLLER_INL_H_ 128