/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/tokenizer.h"

#include <algorithm>

#include "utils/base/logging.h"
#include "utils/base/macros.h"
#include "utils/strings/utf8.h"

namespace libtextclassifier3 {

// Builds a tokenizer of the given `type`.
//
// `codepoint_ranges` (flatbuffer-backed) drive the internal rule-based
// tokenizer: each range maps a span of codepoints to a role (split
// before/after, discard) and a script id. `internal_tokenizer_codepoint_ranges`
// select which codepoints are re-tokenized internally in MIXED mode.
// `unilib` supplies the ICU break iterator and whitespace checks; it is
// borrowed, not owned, and must outlive this object.
Tokenizer::Tokenizer(
    const TokenizationType type, const UniLib* unilib,
    const std::vector<const TokenizationCodepointRange*>& codepoint_ranges,
    const std::vector<const CodepointRange*>&
        internal_tokenizer_codepoint_ranges,
    const bool split_on_script_change,
    const bool icu_preserve_whitespace_tokens)
    : type_(type),
      unilib_(unilib),
      split_on_script_change_(split_on_script_change),
      icu_preserve_whitespace_tokens_(icu_preserve_whitespace_tokens) {
  // Unpack the flatbuffer range objects into owned mutable copies so they can
  // be sorted and binary-searched independently of the flatbuffer storage.
  for (const TokenizationCodepointRange* range : codepoint_ranges) {
    codepoint_ranges_.emplace_back(range->UnPack());
  }

  // Sort by range start so FindTokenizationRange can use std::lower_bound.
  // NOTE(review): assumes the ranges are non-overlapping — confirm with the
  // model data producer; overlapping ranges would make the lookup ambiguous.
  std::sort(codepoint_ranges_.begin(), codepoint_ranges_.end(),
            [](const std::unique_ptr<const TokenizationCodepointRangeT>& a,
               const std::unique_ptr<const TokenizationCodepointRangeT>& b) {
              return a->start < b->start;
            });

  SortCodepointRanges(internal_tokenizer_codepoint_ranges,
                      &internal_tokenizer_codepoint_ranges_);
}

// Binary-searches the sorted codepoint ranges for the one containing
// `codepoint` (half-open interval [start, end)). Returns nullptr when the
// codepoint is not covered by any range.
const TokenizationCodepointRangeT* Tokenizer::FindTokenizationRange(
    int codepoint) const {
  auto it = std::lower_bound(
      codepoint_ranges_.begin(), codepoint_ranges_.end(), codepoint,
      [](const std::unique_ptr<const TokenizationCodepointRangeT>& range,
         int codepoint) {
        // This function compares range with the codepoint for the purpose of
        // finding the first greater or equal range. Because of the use of
        // std::lower_bound it needs to return true when range < codepoint;
        // the first time it will return false the lower bound is found and
        // returned.
        //
        // It might seem weird that the condition is range.end <= codepoint
        // here but when codepoint == range.end it means it's actually just
        // outside of the range, thus the range is less than the codepoint.
        return range->end <= codepoint;
      });
  // lower_bound only guarantees no earlier range can contain the codepoint;
  // the candidate itself must still be checked for actual containment.
  if (it != codepoint_ranges_.end() && (*it)->start <= codepoint &&
      (*it)->end > codepoint) {
    return it->get();
  } else {
    return nullptr;
  }
}

// Looks up the tokenization role and script id for `codepoint`. Codepoints
// not covered by any configured range get DEFAULT_ROLE and kUnknownScript.
void Tokenizer::GetScriptAndRole(char32 codepoint,
                                 TokenizationCodepointRange_::Role* role,
                                 int* script) const {
  const TokenizationCodepointRangeT* range = FindTokenizationRange(codepoint);
  if (range) {
    *role = range->role;
    *script = range->script_id;
  } else {
    *role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
    *script = kUnknownScript;
  }
}

// UTF-8 convenience overload; wraps the bytes without copying and delegates
// to the UnicodeText overload. `text` must outlive the call.
std::vector<Token> Tokenizer::Tokenize(const std::string& text) const {
  UnicodeText text_unicode = UTF8ToUnicodeText(text, /*do_copy=*/false);
  return Tokenize(text_unicode);
}

// Dispatches on the configured tokenization type:
//  - INTERNAL_TOKENIZER: rule-based tokenization only.
//  - ICU: ICU break-iterator tokenization only.
//  - MIXED: ICU first, then selected tokens re-tokenized internally.
// Returns an empty vector if ICU tokenization fails (e.g. no break iterator).
std::vector<Token> Tokenizer::Tokenize(const UnicodeText& text_unicode) const {
  switch (type_) {
    case TokenizationType_INTERNAL_TOKENIZER:
      return InternalTokenize(text_unicode);
    case TokenizationType_ICU:
      TC3_FALLTHROUGH_INTENDED;
    case TokenizationType_MIXED: {
      std::vector<Token> result;
      if (!ICUTokenize(text_unicode, &result)) {
        return {};
      }
      if (type_ == TokenizationType_MIXED) {
        InternalRetokenize(text_unicode, &result);
      }
      return result;
    }
    default:
      TC3_LOG(ERROR) << "Unknown tokenization type specified. "
                        "Using internal.";
      return InternalTokenize(text_unicode);
  }
}

// Rule-based tokenizer: walks the text codepoint by codepoint and applies the
// role bits from the configured ranges. SPLIT_BEFORE/SPLIT_AFTER close the
// current token; DISCARD_CODEPOINT drops the codepoint from the token value
// while still advancing the codepoint index. When split_on_script_change_ is
// set, a script transition also closes the current token. Token start/end are
// codepoint offsets into the input; empty tokens are never emitted.
std::vector<Token> Tokenizer::InternalTokenize(
    const UnicodeText& text_unicode) const {
  std::vector<Token> result;
  Token new_token("", 0, 0);
  int codepoint_index = 0;

  int last_script = kInvalidScript;
  for (auto it = text_unicode.begin(); it != text_unicode.end();
       ++it, ++codepoint_index) {
    TokenizationCodepointRange_::Role role;
    int script;
    GetScriptAndRole(*it, &role, &script);

    if (role & TokenizationCodepointRange_::Role_SPLIT_BEFORE ||
        (split_on_script_change_ && last_script != kInvalidScript &&
         last_script != script)) {
      if (!new_token.value.empty()) {
        result.push_back(new_token);
      }
      new_token = Token("", codepoint_index, codepoint_index);
    }
    if (!(role & TokenizationCodepointRange_::Role_DISCARD_CODEPOINT)) {
      // Append the raw UTF-8 bytes of this codepoint to the token value and
      // extend the token's end offset by one codepoint.
      new_token.value += std::string(
          it.utf8_data(),
          it.utf8_data() + GetNumBytesForNonZeroUTF8Char(it.utf8_data()));
      ++new_token.end;
    }
    if (role & TokenizationCodepointRange_::Role_SPLIT_AFTER) {
      if (!new_token.value.empty()) {
        result.push_back(new_token);
      }
      new_token = Token("", codepoint_index + 1, codepoint_index + 1);
    }

    last_script = script;
  }
  // Flush the trailing token if the text did not end on a split.
  if (!new_token.value.empty()) {
    result.push_back(new_token);
  }

  return result;
}

// Internally tokenizes the codepoint span [span.first, span.second) of
// `unicode_text` and appends the resulting tokens to `result`, shifting their
// offsets so they index into the full text. A negative span.first marks an
// empty/unset span and is a no-op.
void Tokenizer::TokenizeSubstring(const UnicodeText& unicode_text,
                                  CodepointSpan span,
                                  std::vector<Token>* result) const {
  if (span.first < 0) {
    // There is no span to tokenize.
    return;
  }

  // Extract the substring.
  UnicodeText text = UnicodeText::Substring(unicode_text, span.first,
                                            span.second, /*do_copy=*/false);

  // Run the tokenizer and update the token bounds to reflect the offset of the
  // substring.
  std::vector<Token> tokens = InternalTokenize(text);

  // Avoids progressive capacity increases in the for loop.
  result->reserve(result->size() + tokens.size());
  for (Token& token : tokens) {
    token.start += span.first;
    token.end += span.first;
    result->emplace_back(std::move(token));
  }
}

// Second pass for MIXED mode: scans the ICU tokens and coalesces maximal runs
// of consecutive tokens whose codepoints all fall inside
// internal_tokenizer_codepoint_ranges_ into a single span, which is then
// re-tokenized with the internal tokenizer. Tokens outside those ranges are
// passed through unchanged. `tokens` is replaced in place.
void Tokenizer::InternalRetokenize(const UnicodeText& unicode_text,
                                   std::vector<Token>* tokens) const {
  std::vector<Token> result;
  CodepointSpan span(-1, -1);
  for (Token& token : *tokens) {
    const UnicodeText unicode_token_value =
        UTF8ToUnicodeText(token.value, /*do_copy=*/false);
    bool should_retokenize = true;
    for (const int codepoint : unicode_token_value) {
      if (!IsCodepointInRanges(codepoint,
                               internal_tokenizer_codepoint_ranges_)) {
        should_retokenize = false;
        break;
      }
    }

    if (should_retokenize) {
      // Extend (or start) the pending span; it is flushed when a
      // non-retokenizable token or the end of input is reached.
      if (span.first < 0) {
        span.first = token.start;
      }
      span.second = token.end;
    } else {
      TokenizeSubstring(unicode_text, span, &result);
      span.first = -1;
      result.emplace_back(std::move(token));
    }
  }
  // Flush a pending span that ran to the end of the token list.
  TokenizeSubstring(unicode_text, span, &result);

  *tokens = std::move(result);
}

// ICU-based tokenization using UniLib's break iterator. Returns false if the
// break iterator could not be created (result is left untouched); true
// otherwise. Whitespace-only tokens are dropped unless
// icu_preserve_whitespace_tokens_ is set.
//
// NOTE(review): this assumes BreakIterator::Next() returns offsets in the
// same units as the UnicodeText codepoint iteration (token_length is used to
// advance a codepoint iterator). If Next() reports UTF-16 code-unit indices,
// non-BMP characters would desynchronize the offsets — confirm against the
// UniLib implementation.
bool Tokenizer::ICUTokenize(const UnicodeText& context_unicode,
                            std::vector<Token>* result) const {
  std::unique_ptr<UniLib::BreakIterator> break_iterator =
      unilib_->CreateBreakIterator(context_unicode);
  if (!break_iterator) {
    return false;
  }
  int last_break_index = 0;
  int break_index = 0;
  int last_unicode_index = 0;
  int unicode_index = 0;
  auto token_begin_it = context_unicode.begin();
  while ((break_index = break_iterator->Next()) !=
         UniLib::BreakIterator::kDone) {
    const int token_length = break_index - last_break_index;
    unicode_index = last_unicode_index + token_length;

    auto token_end_it = token_begin_it;
    std::advance(token_end_it, token_length);

    // Determine if the whole token is whitespace.
    bool is_whitespace = true;
    for (auto char_it = token_begin_it; char_it < token_end_it; ++char_it) {
      if (!unilib_->IsWhitespace(*char_it)) {
        is_whitespace = false;
        break;
      }
    }

    const std::string token =
        context_unicode.UTF8Substring(token_begin_it, token_end_it);

    if (!is_whitespace || icu_preserve_whitespace_tokens_) {
      result->push_back(Token(token, last_unicode_index, unicode_index));
    }

    last_break_index = break_index;
    last_unicode_index = unicode_index;
    token_begin_it = token_end_it;
  }

  return true;
}

}  // namespace libtextclassifier3