1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "datetime/parser.h" 18 19 #include <set> 20 #include <unordered_set> 21 22 #include "datetime/extractor.h" 23 #include "util/calendar/calendar.h" 24 #include "util/i18n/locale.h" 25 #include "util/strings/split.h" 26 27 namespace libtextclassifier2 { 28 std::unique_ptr<DatetimeParser> DatetimeParser::Instance( 29 const DatetimeModel* model, const UniLib& unilib, 30 ZlibDecompressor* decompressor) { 31 std::unique_ptr<DatetimeParser> result( 32 new DatetimeParser(model, unilib, decompressor)); 33 if (!result->initialized_) { 34 result.reset(); 35 } 36 return result; 37 } 38 39 DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib& unilib, 40 ZlibDecompressor* decompressor) 41 : unilib_(unilib) { 42 initialized_ = false; 43 44 if (model == nullptr) { 45 return; 46 } 47 48 if (model->patterns() != nullptr) { 49 for (const DatetimeModelPattern* pattern : *model->patterns()) { 50 if (pattern->regexes()) { 51 for (const DatetimeModelPattern_::Regex* regex : *pattern->regexes()) { 52 std::unique_ptr<UniLib::RegexPattern> regex_pattern = 53 UncompressMakeRegexPattern(unilib, regex->pattern(), 54 regex->compressed_pattern(), 55 decompressor); 56 if (!regex_pattern) { 57 TC_LOG(ERROR) << "Couldn't create rule pattern."; 58 return; 59 } 60 rules_.push_back({std::move(regex_pattern), regex, pattern}); 61 if (pattern->locales()) { 62 for (int locale : *pattern->locales()) { 63 locale_to_rules_[locale].push_back(rules_.size() - 1); 64 } 65 } 66 } 67 } 68 } 69 } 70 71 if (model->extractors() != nullptr) { 72 for (const DatetimeModelExtractor* extractor : *model->extractors()) { 73 std::unique_ptr<UniLib::RegexPattern> regex_pattern = 74 UncompressMakeRegexPattern(unilib, extractor->pattern(), 75 extractor->compressed_pattern(), 76 decompressor); 77 if (!regex_pattern) { 78 TC_LOG(ERROR) << "Couldn't create extractor pattern"; 79 return; 80 } 81 extractor_rules_.push_back(std::move(regex_pattern)); 82 83 if (extractor->locales()) { 84 for (int locale : *extractor->locales()) { 85 type_and_locale_to_extractor_rule_[extractor->extractor()][locale] = 86 extractor_rules_.size() - 1; 87 } 88 } 89 } 90 } 91 92 if (model->locales() != nullptr) { 93 for (int i = 0; i < model->locales()->Length(); ++i) { 94 locale_string_to_id_[model->locales()->Get(i)->str()] = i; 95 } 96 } 97 98 if (model->default_locales() != nullptr) { 99 for (const int locale : *model->default_locales()) { 100 default_locale_ids_.push_back(locale); 101 } 102 } 103 104 use_extractors_for_locating_ = model->use_extractors_for_locating(); 105 106 initialized_ = true; 107 } 108 109 bool DatetimeParser::Parse( 110 const std::string& input, const int64 reference_time_ms_utc, 111 const std::string& reference_timezone, const std::string& locales, 112 ModeFlag mode, bool anchor_start_end, 113 std::vector<DatetimeParseResultSpan>* results) const { 114 return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false), 115 reference_time_ms_utc, reference_timezone, locales, mode, 116 anchor_start_end, results); 117 } 118 119 bool DatetimeParser::FindSpansUsingLocales( 120 const std::vector<int>& locale_ids, const UnicodeText& input, 121 const int64 reference_time_ms_utc, const std::string& reference_timezone, 122 ModeFlag mode, bool anchor_start_end, const std::string& reference_locale, 123 std::unordered_set<int>* executed_rules, 124 std::vector<DatetimeParseResultSpan>* found_spans) const { 125 for (const int locale_id : locale_ids) { 126 auto rules_it = locale_to_rules_.find(locale_id); 127 if (rules_it == locale_to_rules_.end()) { 128 continue; 129 } 130 131 for (const int rule_id : rules_it->second) { 132 // Skip rules that were already executed in previous locales. 133 if (executed_rules->find(rule_id) != executed_rules->end()) { 134 continue; 135 } 136 137 if (!(rules_[rule_id].pattern->enabled_modes() & mode)) { 138 continue; 139 } 140 141 executed_rules->insert(rule_id); 142 143 if (!ParseWithRule(rules_[rule_id], input, reference_time_ms_utc, 144 reference_timezone, reference_locale, locale_id, 145 anchor_start_end, found_spans)) { 146 return false; 147 } 148 } 149 } 150 return true; 151 } 152 153 bool DatetimeParser::Parse( 154 const UnicodeText& input, const int64 reference_time_ms_utc, 155 const std::string& reference_timezone, const std::string& locales, 156 ModeFlag mode, bool anchor_start_end, 157 std::vector<DatetimeParseResultSpan>* results) const { 158 std::vector<DatetimeParseResultSpan> found_spans; 159 std::unordered_set<int> executed_rules; 160 std::string reference_locale; 161 const std::vector<int> requested_locales = 162 ParseAndExpandLocales(locales, &reference_locale); 163 if (!FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc, 164 reference_timezone, mode, anchor_start_end, 165 reference_locale, &executed_rules, &found_spans)) { 166 return false; 167 } 168 169 std::vector<std::pair<DatetimeParseResultSpan, int>> indexed_found_spans; 170 int counter = 0; 171 for (const auto& found_span : found_spans) { 172 indexed_found_spans.push_back({found_span, counter}); 173 counter++; 174 } 175 176 // Resolve conflicts by always picking the longer span and breaking ties by 177 // selecting the earlier entry in the list for a given locale. 178 std::sort(indexed_found_spans.begin(), indexed_found_spans.end(), 179 [](const std::pair<DatetimeParseResultSpan, int>& a, 180 const std::pair<DatetimeParseResultSpan, int>& b) { 181 if ((a.first.span.second - a.first.span.first) != 182 (b.first.span.second - b.first.span.first)) { 183 return (a.first.span.second - a.first.span.first) > 184 (b.first.span.second - b.first.span.first); 185 } else { 186 return a.second < b.second; 187 } 188 }); 189 190 found_spans.clear(); 191 for (auto& span_index_pair : indexed_found_spans) { 192 found_spans.push_back(span_index_pair.first); 193 } 194 195 std::set<int, std::function<bool(int, int)>> chosen_indices_set( 196 [&found_spans](int a, int b) { 197 return found_spans[a].span.first < found_spans[b].span.first; 198 }); 199 for (int i = 0; i < found_spans.size(); ++i) { 200 if (!DoesCandidateConflict(i, found_spans, chosen_indices_set)) { 201 chosen_indices_set.insert(i); 202 results->push_back(found_spans[i]); 203 } 204 } 205 206 return true; 207 } 208 209 bool DatetimeParser::HandleParseMatch( 210 const CompiledRule& rule, const UniLib::RegexMatcher& matcher, 211 int64 reference_time_ms_utc, const std::string& reference_timezone, 212 const std::string& reference_locale, int locale_id, 213 std::vector<DatetimeParseResultSpan>* result) const { 214 int status = UniLib::RegexMatcher::kNoError; 215 const int start = matcher.Start(&status); 216 if (status != UniLib::RegexMatcher::kNoError) { 217 return false; 218 } 219 220 const int end = matcher.End(&status); 221 if (status != UniLib::RegexMatcher::kNoError) { 222 return false; 223 } 224 225 DatetimeParseResultSpan parse_result; 226 if (!ExtractDatetime(rule, matcher, reference_time_ms_utc, reference_timezone, 227 reference_locale, locale_id, &(parse_result.data), 228 &parse_result.span)) { 229 return false; 230 } 231 if (!use_extractors_for_locating_) { 232 parse_result.span = {start, end}; 233 } 234 if (parse_result.span.first != kInvalidIndex && 235 parse_result.span.second != kInvalidIndex) { 236 parse_result.target_classification_score = 237 rule.pattern->target_classification_score(); 238 parse_result.priority_score = rule.pattern->priority_score(); 239 result->push_back(parse_result); 240 } 241 return true; 242 } 243 244 bool DatetimeParser::ParseWithRule( 245 const CompiledRule& rule, const UnicodeText& input, 246 const int64 reference_time_ms_utc, const std::string& reference_timezone, 247 const std::string& reference_locale, const int locale_id, 248 bool anchor_start_end, std::vector<DatetimeParseResultSpan>* result) const { 249 std::unique_ptr<UniLib::RegexMatcher> matcher = 250 rule.compiled_regex->Matcher(input); 251 int status = UniLib::RegexMatcher::kNoError; 252 if (anchor_start_end) { 253 if (matcher->Matches(&status) && status == UniLib::RegexMatcher::kNoError) { 254 if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc, 255 reference_timezone, reference_locale, locale_id, 256 result)) { 257 return false; 258 } 259 } 260 } else { 261 while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) { 262 if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc, 263 reference_timezone, reference_locale, locale_id, 264 result)) { 265 return false; 266 } 267 } 268 } 269 return true; 270 } 271 272 std::vector<int> DatetimeParser::ParseAndExpandLocales( 273 const std::string& locales, std::string* reference_locale) const { 274 std::vector<StringPiece> split_locales = strings::Split(locales, ','); 275 if (!split_locales.empty()) { 276 *reference_locale = split_locales[0].ToString(); 277 } else { 278 *reference_locale = ""; 279 } 280 281 std::vector<int> result; 282 for (const StringPiece& locale_str : split_locales) { 283 auto locale_it = locale_string_to_id_.find(locale_str.ToString()); 284 if (locale_it != locale_string_to_id_.end()) { 285 result.push_back(locale_it->second); 286 } 287 288 const Locale locale = Locale::FromBCP47(locale_str.ToString()); 289 if (!locale.IsValid()) { 290 continue; 291 } 292 293 const std::string language = locale.Language(); 294 const std::string script = locale.Script(); 295 const std::string region = locale.Region(); 296 297 // First, try adding *-region locale. 298 if (!region.empty()) { 299 locale_it = locale_string_to_id_.find("*-" + region); 300 if (locale_it != locale_string_to_id_.end()) { 301 result.push_back(locale_it->second); 302 } 303 } 304 // Second, try adding language-script-* locale. 305 if (!script.empty()) { 306 locale_it = locale_string_to_id_.find(language + "-" + script + "-*"); 307 if (locale_it != locale_string_to_id_.end()) { 308 result.push_back(locale_it->second); 309 } 310 } 311 // Third, try adding language-* locale. 312 if (!language.empty()) { 313 locale_it = locale_string_to_id_.find(language + "-*"); 314 if (locale_it != locale_string_to_id_.end()) { 315 result.push_back(locale_it->second); 316 } 317 } 318 } 319 320 // Add the default locales if they haven't been added already. 321 const std::unordered_set<int> result_set(result.begin(), result.end()); 322 for (const int default_locale_id : default_locale_ids_) { 323 if (result_set.find(default_locale_id) == result_set.end()) { 324 result.push_back(default_locale_id); 325 } 326 } 327 328 return result; 329 } 330 331 namespace { 332 333 DatetimeGranularity GetGranularity(const DateParseData& data) { 334 DatetimeGranularity granularity = DatetimeGranularity::GRANULARITY_YEAR; 335 if ((data.field_set_mask & DateParseData::YEAR_FIELD) || 336 (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD && 337 (data.relation_type == DateParseData::RelationType::YEAR))) { 338 granularity = DatetimeGranularity::GRANULARITY_YEAR; 339 } 340 if ((data.field_set_mask & DateParseData::MONTH_FIELD) || 341 (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD && 342 (data.relation_type == DateParseData::RelationType::MONTH))) { 343 granularity = DatetimeGranularity::GRANULARITY_MONTH; 344 } 345 if (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD && 346 (data.relation_type == DateParseData::RelationType::WEEK)) { 347 granularity = DatetimeGranularity::GRANULARITY_WEEK; 348 } 349 if (data.field_set_mask & DateParseData::DAY_FIELD || 350 (data.field_set_mask & DateParseData::RELATION_FIELD && 351 (data.relation == DateParseData::Relation::NOW || 352 data.relation == DateParseData::Relation::TOMORROW || 353 data.relation == DateParseData::Relation::YESTERDAY)) || 354 (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD && 355 (data.relation_type == DateParseData::RelationType::MONDAY || 356 data.relation_type == DateParseData::RelationType::TUESDAY || 357 data.relation_type == DateParseData::RelationType::WEDNESDAY || 358 data.relation_type == DateParseData::RelationType::THURSDAY || 359 data.relation_type == DateParseData::RelationType::FRIDAY || 360 data.relation_type == DateParseData::RelationType::SATURDAY || 361 data.relation_type == DateParseData::RelationType::SUNDAY || 362 data.relation_type == DateParseData::RelationType::DAY))) { 363 granularity = DatetimeGranularity::GRANULARITY_DAY; 364 } 365 if (data.field_set_mask & DateParseData::HOUR_FIELD) { 366 granularity = DatetimeGranularity::GRANULARITY_HOUR; 367 } 368 if (data.field_set_mask & DateParseData::MINUTE_FIELD) { 369 granularity = DatetimeGranularity::GRANULARITY_MINUTE; 370 } 371 if (data.field_set_mask & DateParseData::SECOND_FIELD) { 372 granularity = DatetimeGranularity::GRANULARITY_SECOND; 373 } 374 return granularity; 375 } 376 377 } // namespace 378 379 bool DatetimeParser::ExtractDatetime(const CompiledRule& rule, 380 const UniLib::RegexMatcher& matcher, 381 const int64 reference_time_ms_utc, 382 const std::string& reference_timezone, 383 const std::string& reference_locale, 384 int locale_id, DatetimeParseResult* result, 385 CodepointSpan* result_span) const { 386 DateParseData parse; 387 DatetimeExtractor extractor(rule, matcher, locale_id, unilib_, 388 extractor_rules_, 389 type_and_locale_to_extractor_rule_); 390 if (!extractor.Extract(&parse, result_span)) { 391 return false; 392 } 393 394 result->granularity = GetGranularity(parse); 395 396 if (!calendar_lib_.InterpretParseData( 397 parse, reference_time_ms_utc, reference_timezone, reference_locale, 398 result->granularity, &(result->time_ms_utc))) { 399 return false; 400 } 401 402 return true; 403 } 404 405 } // namespace libtextclassifier2 406