Home | History | Annotate | Download | only in utils
      1 /*
      2  * Copyright (C) 2013, The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *     http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
     18 
     19 #include <algorithm>
     20 #include <cmath>
     21 #include <stdlib.h>
     22 
     23 #include "suggest/policyimpl/dictionary/header/header_policy.h"
     24 #include "suggest/policyimpl/dictionary/utils/probability_utils.h"
     25 #include "utils/time_keeper.h"
     26 
     27 namespace latinime {
     28 
     29 const int ForgettingCurveUtils::MULTIPLIER_TWO_IN_PROBABILITY_SCALE = 8;
     30 const int ForgettingCurveUtils::DECAY_INTERVAL_SECONDS = 2 * 60 * 60;
     31 
     32 const int ForgettingCurveUtils::MAX_LEVEL = 3;
     33 const int ForgettingCurveUtils::MIN_VISIBLE_LEVEL = 1;
     34 const int ForgettingCurveUtils::MAX_ELAPSED_TIME_STEP_COUNT = 15;
     35 const int ForgettingCurveUtils::DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD = 14;
     36 
     37 const float ForgettingCurveUtils::UNIGRAM_COUNT_HARD_LIMIT_WEIGHT = 1.2;
     38 const float ForgettingCurveUtils::BIGRAM_COUNT_HARD_LIMIT_WEIGHT = 1.2;
     39 
     40 const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityTable;
     41 
     42 // TODO: Revise the logic to decide the initial probability depending on the given probability.
     43 /* static */ const HistoricalInfo ForgettingCurveUtils::createUpdatedHistoricalInfo(
     44         const HistoricalInfo *const originalHistoricalInfo, const int newProbability,
     45         const HistoricalInfo *const newHistoricalInfo, const HeaderPolicy *const headerPolicy) {
     46     const int timestamp = newHistoricalInfo->getTimeStamp();
     47     if (newProbability != NOT_A_PROBABILITY && originalHistoricalInfo->getLevel() == 0) {
     48         // Add entry as a valid word.
     49         const int level = clampToVisibleEntryLevelRange(newHistoricalInfo->getLevel());
     50         const int count = clampToValidCountRange(newHistoricalInfo->getCount(), headerPolicy);
     51         return HistoricalInfo(timestamp, level, count);
     52     } else if (!originalHistoricalInfo->isValid()
     53             || originalHistoricalInfo->getLevel() < newHistoricalInfo->getLevel()
     54             || (originalHistoricalInfo->getLevel() == newHistoricalInfo->getLevel()
     55                     && originalHistoricalInfo->getCount() < newHistoricalInfo->getCount())) {
     56         // Initial information.
     57         const int level = clampToValidLevelRange(newHistoricalInfo->getLevel());
     58         const int count = clampToValidCountRange(newHistoricalInfo->getCount(), headerPolicy);
     59         return HistoricalInfo(timestamp, level, count);
     60     } else {
     61         const int updatedCount = originalHistoricalInfo->getCount() + 1;
     62         if (updatedCount >= headerPolicy->getForgettingCurveOccurrencesToLevelUp()) {
     63             // The count exceeds the max value the level can be incremented.
     64             if (originalHistoricalInfo->getLevel() >= MAX_LEVEL) {
     65                 // The level is already max.
     66                 return HistoricalInfo(timestamp,
     67                         originalHistoricalInfo->getLevel(), originalHistoricalInfo->getCount());
     68             } else {
     69                 // Level up.
     70                 return HistoricalInfo(timestamp,
     71                         originalHistoricalInfo->getLevel() + 1, 0 /* count */);
     72             }
     73         } else {
     74             return HistoricalInfo(timestamp, originalHistoricalInfo->getLevel(), updatedCount);
     75         }
     76     }
     77 }
     78 
     79 /* static */ int ForgettingCurveUtils::decodeProbability(
     80         const HistoricalInfo *const historicalInfo, const HeaderPolicy *const headerPolicy) {
     81     const int elapsedTimeStepCount = getElapsedTimeStepCount(historicalInfo->getTimeStamp(),
     82             headerPolicy->getForgettingCurveDurationToLevelDown());
     83     return sProbabilityTable.getProbability(
     84             headerPolicy->getForgettingCurveProbabilityValuesTableId(),
     85             clampToValidLevelRange(historicalInfo->getLevel()),
     86             clampToValidTimeStepCountRange(elapsedTimeStepCount));
     87 }
     88 
     89 /* static */ int ForgettingCurveUtils::getProbability(const int unigramProbability,
     90         const int bigramProbability) {
     91     if (unigramProbability == NOT_A_PROBABILITY) {
     92         return NOT_A_PROBABILITY;
     93     } else if (bigramProbability == NOT_A_PROBABILITY) {
     94         return std::min(backoff(unigramProbability), MAX_PROBABILITY);
     95     } else {
     96         // TODO: Investigate better way to handle bigram probability.
     97         return std::min(std::max(unigramProbability,
     98                 bigramProbability + MULTIPLIER_TWO_IN_PROBABILITY_SCALE), MAX_PROBABILITY);
     99     }
    100 }
    101 
    102 /* static */ bool ForgettingCurveUtils::needsToKeep(const HistoricalInfo *const historicalInfo,
    103         const HeaderPolicy *const headerPolicy) {
    104     return historicalInfo->getLevel() > 0
    105             || getElapsedTimeStepCount(historicalInfo->getTimeStamp(),
    106                     headerPolicy->getForgettingCurveDurationToLevelDown())
    107                             < DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD;
    108 }
    109 
    110 /* static */ const HistoricalInfo ForgettingCurveUtils::createHistoricalInfoToSave(
    111         const HistoricalInfo *const originalHistoricalInfo,
    112         const HeaderPolicy *const headerPolicy) {
    113     if (originalHistoricalInfo->getTimeStamp() == NOT_A_TIMESTAMP) {
    114         return HistoricalInfo();
    115     }
    116     const int durationToLevelDownInSeconds = headerPolicy->getForgettingCurveDurationToLevelDown();
    117     const int elapsedTimeStep = getElapsedTimeStepCount(
    118             originalHistoricalInfo->getTimeStamp(), durationToLevelDownInSeconds);
    119     if (elapsedTimeStep <= MAX_ELAPSED_TIME_STEP_COUNT) {
    120         // No need to update historical info.
    121         return *originalHistoricalInfo;
    122     }
    123     // Level down.
    124     const int maxLevelDownAmonut = elapsedTimeStep / (MAX_ELAPSED_TIME_STEP_COUNT + 1);
    125     const int levelDownAmount = (maxLevelDownAmonut >= originalHistoricalInfo->getLevel()) ?
    126             originalHistoricalInfo->getLevel() : maxLevelDownAmonut;
    127     const int adjustedTimestampInSeconds = originalHistoricalInfo->getTimeStamp() +
    128             levelDownAmount * durationToLevelDownInSeconds;
    129     return HistoricalInfo(adjustedTimestampInSeconds,
    130             originalHistoricalInfo->getLevel() - levelDownAmount, 0 /* count */);
    131 }
    132 
    133 /* static */ bool ForgettingCurveUtils::needsToDecay(const bool mindsBlockByDecay,
    134         const int unigramCount, const int bigramCount, const HeaderPolicy *const headerPolicy) {
    135     if (unigramCount >= getUnigramCountHardLimit(headerPolicy->getMaxUnigramCount())) {
    136         // Unigram count exceeds the limit.
    137         return true;
    138     } else if (bigramCount >= getBigramCountHardLimit(headerPolicy->getMaxBigramCount())) {
    139         // Bigram count exceeds the limit.
    140         return true;
    141     }
    142     if (mindsBlockByDecay) {
    143         return false;
    144     }
    145     if (headerPolicy->getLastDecayedTime() + DECAY_INTERVAL_SECONDS
    146             < TimeKeeper::peekCurrentTime()) {
    147         // Time to decay.
    148         return true;
    149     }
    150     return false;
    151 }
    152 
// See comments in ProbabilityUtils::backoff().
// Currently returns the unigram probability unchanged; getProbability() uses the result when no
// bigram probability is available.
/* static */ int ForgettingCurveUtils::backoff(const int unigramProbability) {
    // See TODO comments in ForgettingCurveUtils::getProbability().
    return unigramProbability;
}
    158 
    159 /* static */ int ForgettingCurveUtils::getElapsedTimeStepCount(const int timestamp,
    160         const int durationToLevelDownInSeconds) {
    161     const int elapsedTimeInSeconds = TimeKeeper::peekCurrentTime() - timestamp;
    162     const int timeStepDurationInSeconds =
    163             durationToLevelDownInSeconds / (MAX_ELAPSED_TIME_STEP_COUNT + 1);
    164     return elapsedTimeInSeconds / timeStepDurationInSeconds;
    165 }
    166 
    167 /* static */ int ForgettingCurveUtils::clampToVisibleEntryLevelRange(const int level) {
    168     return std::min(std::max(level, MIN_VISIBLE_LEVEL), MAX_LEVEL);
    169 }
    170 
    171 /* static */ int ForgettingCurveUtils::clampToValidCountRange(const int count,
    172         const HeaderPolicy *const headerPolicy) {
    173     return std::min(std::max(count, 0), headerPolicy->getForgettingCurveOccurrencesToLevelUp() - 1);
    174 }
    175 
    176 /* static */ int ForgettingCurveUtils::clampToValidLevelRange(const int level) {
    177     return std::min(std::max(level, 0), MAX_LEVEL);
    178 }
    179 
    180 /* static */ int ForgettingCurveUtils::clampToValidTimeStepCountRange(const int timeStepCount) {
    181     return std::min(std::max(timeStepCount, 0), MAX_ELAPSED_TIME_STEP_COUNT);
    182 }
    183 
// Number of probability tables built by the ProbabilityTable constructor.
const int ForgettingCurveUtils::ProbabilityTable::PROBABILITY_TABLE_COUNT = 4;
// Table ids; each id selects a different base probability formula in
// getBaseProbabilityForLevel().
const int ForgettingCurveUtils::ProbabilityTable::WEAK_PROBABILITY_TABLE_ID = 0;
const int ForgettingCurveUtils::ProbabilityTable::MODEST_PROBABILITY_TABLE_ID = 1;
const int ForgettingCurveUtils::ProbabilityTable::STRONG_PROBABILITY_TABLE_ID = 2;
const int ForgettingCurveUtils::ProbabilityTable::AGGRESSIVE_PROBABILITY_TABLE_ID = 3;
// Base probability of the weak table at MAX_LEVEL; it is halved for each level below that.
const int ForgettingCurveUtils::ProbabilityTable::WEAK_MAX_PROBABILITY = 127;
// Per-table factors that are multiplied by (level + 1) to get the base probability of a level.
const int ForgettingCurveUtils::ProbabilityTable::MODEST_BASE_PROBABILITY = 32;
const int ForgettingCurveUtils::ProbabilityTable::STRONG_BASE_PROBABILITY = 35;
const int ForgettingCurveUtils::ProbabilityTable::AGGRESSIVE_BASE_PROBABILITY = 40;
    193 
    194 
    195 ForgettingCurveUtils::ProbabilityTable::ProbabilityTable() : mTables() {
    196     mTables.resize(PROBABILITY_TABLE_COUNT);
    197     for (int tableId = 0; tableId < PROBABILITY_TABLE_COUNT; ++tableId) {
    198         mTables[tableId].resize(MAX_LEVEL + 1);
    199         for (int level = 0; level <= MAX_LEVEL; ++level) {
    200             mTables[tableId][level].resize(MAX_ELAPSED_TIME_STEP_COUNT + 1);
    201             const float initialProbability = getBaseProbabilityForLevel(tableId, level);
    202             const float endProbability = getBaseProbabilityForLevel(tableId, level - 1);
    203             for (int timeStepCount = 0; timeStepCount <= MAX_ELAPSED_TIME_STEP_COUNT;
    204                     ++timeStepCount) {
    205                 if (level == 0) {
    206                     mTables[tableId][level][timeStepCount] = NOT_A_PROBABILITY;
    207                     continue;
    208                 }
    209                 const float probability = initialProbability
    210                         * powf(initialProbability / endProbability,
    211                                 -1.0f * static_cast<float>(timeStepCount)
    212                                         / static_cast<float>(MAX_ELAPSED_TIME_STEP_COUNT + 1));
    213                 mTables[tableId][level][timeStepCount] =
    214                         std::min(std::max(static_cast<int>(probability), 1), MAX_PROBABILITY);
    215             }
    216         }
    217     }
    218 }
    219 
    220 /* static */ int ForgettingCurveUtils::ProbabilityTable::getBaseProbabilityForLevel(
    221         const int tableId, const int level) {
    222     if (tableId == WEAK_PROBABILITY_TABLE_ID) {
    223         // Max probability is 127.
    224         return static_cast<float>(WEAK_MAX_PROBABILITY / (1 << (MAX_LEVEL - level)));
    225     } else if (tableId == MODEST_PROBABILITY_TABLE_ID) {
    226         // Max probability is 128.
    227         return static_cast<float>(MODEST_BASE_PROBABILITY * (level + 1));
    228     } else if (tableId == STRONG_PROBABILITY_TABLE_ID) {
    229         // Max probability is 140.
    230         return static_cast<float>(STRONG_BASE_PROBABILITY * (level + 1));
    231     } else if (tableId == AGGRESSIVE_PROBABILITY_TABLE_ID) {
    232         // Max probability is 160.
    233         return static_cast<float>(AGGRESSIVE_BASE_PROBABILITY * (level + 1));
    234     } else {
    235         return NOT_A_PROBABILITY;
    236     }
    237 }
    238 
    239 } // namespace latinime
    240