Home | History | Annotate | Download | only in session
      1 /*
      2  * Copyright (C) 2012 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "suggest/core/session/dic_traverse_session.h"
     18 
     19 #include "binary_format.h"
     20 #include "defines.h"
     21 #include "dictionary.h"
     22 #include "dic_traverse_wrapper.h"
     23 #include "jni.h"
     24 #include "suggest/core/dicnode/dic_node_utils.h"
     25 
     26 namespace latinime {
     27 
     28 const int DicTraverseSession::CACHE_START_INPUT_LENGTH_THRESHOLD = 20;
     29 
     30 // A factory method for DicTraverseSession
     31 static void *getSessionInstance(JNIEnv *env, jstring localeStr) {
     32     return new DicTraverseSession(env, localeStr);
     33 }
     34 
     35 // TODO: Pass "DicTraverseSession *traverseSession" when the source code structure settles down.
     36 static void initSessionInstance(void *traverseSession, const Dictionary *const dictionary,
     37         const int *prevWord, const int prevWordLength) {
     38     if (traverseSession) {
     39         DicTraverseSession *tSession = static_cast<DicTraverseSession *>(traverseSession);
     40         tSession->init(dictionary, prevWord, prevWordLength);
     41     }
     42 }
     43 
     44 // TODO: Pass "DicTraverseSession *traverseSession" when the source code structure settles down.
     45 static void releaseSessionInstance(void *traverseSession) {
     46     delete static_cast<DicTraverseSession *>(traverseSession);
     47 }
     48 
     49 // An ad-hoc internal class to register the factory method defined above
     50 class TraverseSessionFactoryRegisterer {
     51  public:
     52     TraverseSessionFactoryRegisterer() {
     53         DicTraverseWrapper::setTraverseSessionFactoryMethod(getSessionInstance);
     54         DicTraverseWrapper::setTraverseSessionInitMethod(initSessionInstance);
     55         DicTraverseWrapper::setTraverseSessionReleaseMethod(releaseSessionInstance);
     56     }
     57  private:
     58     DISALLOW_COPY_AND_ASSIGN(TraverseSessionFactoryRegisterer);
     59 };
     60 
     61 // To invoke the TraverseSessionFactoryRegisterer constructor in the global constructor.
     62 static TraverseSessionFactoryRegisterer traverseSessionFactoryRegisterer;
     63 
     64 void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord,
     65         int prevWordLength) {
     66     mDictionary = dictionary;
     67     mMultiWordCostMultiplier = BinaryFormat::getMultiWordCostMultiplier(mDictionary->getDict(),
     68             mDictionary->getDictSize());
     69     if (!prevWord) {
     70         mPrevWordPos = NOT_VALID_WORD;
     71         return;
     72     }
     73     // TODO: merge following similar calls to getTerminalPosition into one case-insensitive call.
     74     mPrevWordPos = BinaryFormat::getTerminalPosition(dictionary->getOffsetDict(), prevWord,
     75             prevWordLength, false /* forceLowerCaseSearch */);
     76     if (mPrevWordPos == NOT_VALID_WORD) {
     77         // Check bigrams for lower-cased previous word if original was not found. Useful for
     78         // auto-capitalized words like "The [current_word]".
     79         mPrevWordPos = BinaryFormat::getTerminalPosition(dictionary->getOffsetDict(), prevWord,
     80                 prevWordLength, true /* forceLowerCaseSearch */);
     81     }
     82 }
     83 
     84 void DicTraverseSession::setupForGetSuggestions(const ProximityInfo *pInfo,
     85         const int *inputCodePoints, const int inputSize, const int *const inputXs,
     86         const int *const inputYs, const int *const times, const int *const pointerIds,
     87         const float maxSpatialDistance, const int maxPointerCount) {
     88     mProximityInfo = pInfo;
     89     mMaxPointerCount = maxPointerCount;
     90     initializeProximityInfoStates(inputCodePoints, inputXs, inputYs, times, pointerIds, inputSize,
     91             maxSpatialDistance, maxPointerCount);
     92 }
     93 
     94 const uint8_t *DicTraverseSession::getOffsetDict() const {
     95     return mDictionary->getOffsetDict();
     96 }
     97 
     98 int DicTraverseSession::getDictFlags() const {
     99     return mDictionary->getDictFlags();
    100 }
    101 
    102 void DicTraverseSession::resetCache(const int nextActiveCacheSize, const int maxWords) {
    103     mDicNodesCache.reset(nextActiveCacheSize, maxWords);
    104     mMultiBigramMap.clear();
    105     mPartiallyCommited = false;
    106 }
    107 
    108 void DicTraverseSession::initializeProximityInfoStates(const int *const inputCodePoints,
    109         const int *const inputXs, const int *const inputYs, const int *const times,
    110         const int *const pointerIds, const int inputSize, const float maxSpatialDistance,
    111         const int maxPointerCount) {
    112     ASSERT(1 <= maxPointerCount && maxPointerCount <= MAX_POINTER_COUNT_G);
    113     mInputSize = 0;
    114     for (int i = 0; i < maxPointerCount; ++i) {
    115         mProximityInfoStates[i].initInputParams(i, maxSpatialDistance, getProximityInfo(),
    116                 inputCodePoints, inputSize, inputXs, inputYs, times, pointerIds,
    117                 maxPointerCount == MAX_POINTER_COUNT_G
    118                 /* TODO: this is a hack. fix proximity info state */);
    119         mInputSize += mProximityInfoStates[i].size();
    120     }
    121 }
    122 } // namespace latinime
    123