Home | History | Annotate | Download | only in libtextclassifier
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "tokenizer.h"
     18 
     19 #include <algorithm>
     20 
     21 #include "util/base/logging.h"
     22 #include "util/strings/utf8.h"
     23 
     24 namespace libtextclassifier2 {
     25 
     26 Tokenizer::Tokenizer(
     27     const std::vector<const TokenizationCodepointRange*>& codepoint_ranges,
     28     bool split_on_script_change)
     29     : split_on_script_change_(split_on_script_change) {
     30   for (const TokenizationCodepointRange* range : codepoint_ranges) {
     31     codepoint_ranges_.emplace_back(range->UnPack());
     32   }
     33 
     34   std::sort(codepoint_ranges_.begin(), codepoint_ranges_.end(),
     35             [](const std::unique_ptr<const TokenizationCodepointRangeT>& a,
     36                const std::unique_ptr<const TokenizationCodepointRangeT>& b) {
     37               return a->start < b->start;
     38             });
     39 }
     40 
     41 const TokenizationCodepointRangeT* Tokenizer::FindTokenizationRange(
     42     int codepoint) const {
     43   auto it = std::lower_bound(
     44       codepoint_ranges_.begin(), codepoint_ranges_.end(), codepoint,
     45       [](const std::unique_ptr<const TokenizationCodepointRangeT>& range,
     46          int codepoint) {
     47         // This function compares range with the codepoint for the purpose of
     48         // finding the first greater or equal range. Because of the use of
     49         // std::lower_bound it needs to return true when range < codepoint;
     50         // the first time it will return false the lower bound is found and
     51         // returned.
     52         //
     53         // It might seem weird that the condition is range.end <= codepoint
     54         // here but when codepoint == range.end it means it's actually just
     55         // outside of the range, thus the range is less than the codepoint.
     56         return range->end <= codepoint;
     57       });
     58   if (it != codepoint_ranges_.end() && (*it)->start <= codepoint &&
     59       (*it)->end > codepoint) {
     60     return it->get();
     61   } else {
     62     return nullptr;
     63   }
     64 }
     65 
     66 void Tokenizer::GetScriptAndRole(char32 codepoint,
     67                                  TokenizationCodepointRange_::Role* role,
     68                                  int* script) const {
     69   const TokenizationCodepointRangeT* range = FindTokenizationRange(codepoint);
     70   if (range) {
     71     *role = range->role;
     72     *script = range->script_id;
     73   } else {
     74     *role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
     75     *script = kUnknownScript;
     76   }
     77 }
     78 
     79 std::vector<Token> Tokenizer::Tokenize(const std::string& text) const {
     80   UnicodeText text_unicode = UTF8ToUnicodeText(text, /*do_copy=*/false);
     81   return Tokenize(text_unicode);
     82 }
     83 
     84 std::vector<Token> Tokenizer::Tokenize(const UnicodeText& text_unicode) const {
     85   std::vector<Token> result;
     86   Token new_token("", 0, 0);
     87   int codepoint_index = 0;
     88 
     89   int last_script = kInvalidScript;
     90   for (auto it = text_unicode.begin(); it != text_unicode.end();
     91        ++it, ++codepoint_index) {
     92     TokenizationCodepointRange_::Role role;
     93     int script;
     94     GetScriptAndRole(*it, &role, &script);
     95 
     96     if (role & TokenizationCodepointRange_::Role_SPLIT_BEFORE ||
     97         (split_on_script_change_ && last_script != kInvalidScript &&
     98          last_script != script)) {
     99       if (!new_token.value.empty()) {
    100         result.push_back(new_token);
    101       }
    102       new_token = Token("", codepoint_index, codepoint_index);
    103     }
    104     if (!(role & TokenizationCodepointRange_::Role_DISCARD_CODEPOINT)) {
    105       new_token.value += std::string(
    106           it.utf8_data(),
    107           it.utf8_data() + GetNumBytesForNonZeroUTF8Char(it.utf8_data()));
    108       ++new_token.end;
    109     }
    110     if (role & TokenizationCodepointRange_::Role_SPLIT_AFTER) {
    111       if (!new_token.value.empty()) {
    112         result.push_back(new_token);
    113       }
    114       new_token = Token("", codepoint_index + 1, codepoint_index + 1);
    115     }
    116 
    117     last_script = script;
    118   }
    119   if (!new_token.value.empty()) {
    120     result.push_back(new_token);
    121   }
    122 
    123   return result;
    124 }
    125 
    126 }  // namespace libtextclassifier2
    127