Home | History | Annotate | Download | only in datetime
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
#include "datetime/parser.h"

#include <algorithm>
#include <functional>
#include <set>
#include <unordered_set>

#include "datetime/extractor.h"
#include "util/calendar/calendar.h"
#include "util/i18n/locale.h"
#include "util/strings/split.h"
     26 
     27 namespace libtextclassifier2 {
     28 std::unique_ptr<DatetimeParser> DatetimeParser::Instance(
     29     const DatetimeModel* model, const UniLib& unilib,
     30     ZlibDecompressor* decompressor) {
     31   std::unique_ptr<DatetimeParser> result(
     32       new DatetimeParser(model, unilib, decompressor));
     33   if (!result->initialized_) {
     34     result.reset();
     35   }
     36   return result;
     37 }
     38 
     39 DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib& unilib,
     40                                ZlibDecompressor* decompressor)
     41     : unilib_(unilib) {
     42   initialized_ = false;
     43 
     44   if (model == nullptr) {
     45     return;
     46   }
     47 
     48   if (model->patterns() != nullptr) {
     49     for (const DatetimeModelPattern* pattern : *model->patterns()) {
     50       if (pattern->regexes()) {
     51         for (const DatetimeModelPattern_::Regex* regex : *pattern->regexes()) {
     52           std::unique_ptr<UniLib::RegexPattern> regex_pattern =
     53               UncompressMakeRegexPattern(unilib, regex->pattern(),
     54                                          regex->compressed_pattern(),
     55                                          decompressor);
     56           if (!regex_pattern) {
     57             TC_LOG(ERROR) << "Couldn't create rule pattern.";
     58             return;
     59           }
     60           rules_.push_back({std::move(regex_pattern), regex, pattern});
     61           if (pattern->locales()) {
     62             for (int locale : *pattern->locales()) {
     63               locale_to_rules_[locale].push_back(rules_.size() - 1);
     64             }
     65           }
     66         }
     67       }
     68     }
     69   }
     70 
     71   if (model->extractors() != nullptr) {
     72     for (const DatetimeModelExtractor* extractor : *model->extractors()) {
     73       std::unique_ptr<UniLib::RegexPattern> regex_pattern =
     74           UncompressMakeRegexPattern(unilib, extractor->pattern(),
     75                                      extractor->compressed_pattern(),
     76                                      decompressor);
     77       if (!regex_pattern) {
     78         TC_LOG(ERROR) << "Couldn't create extractor pattern";
     79         return;
     80       }
     81       extractor_rules_.push_back(std::move(regex_pattern));
     82 
     83       if (extractor->locales()) {
     84         for (int locale : *extractor->locales()) {
     85           type_and_locale_to_extractor_rule_[extractor->extractor()][locale] =
     86               extractor_rules_.size() - 1;
     87         }
     88       }
     89     }
     90   }
     91 
     92   if (model->locales() != nullptr) {
     93     for (int i = 0; i < model->locales()->Length(); ++i) {
     94       locale_string_to_id_[model->locales()->Get(i)->str()] = i;
     95     }
     96   }
     97 
     98   if (model->default_locales() != nullptr) {
     99     for (const int locale : *model->default_locales()) {
    100       default_locale_ids_.push_back(locale);
    101     }
    102   }
    103 
    104   use_extractors_for_locating_ = model->use_extractors_for_locating();
    105 
    106   initialized_ = true;
    107 }
    108 
    109 bool DatetimeParser::Parse(
    110     const std::string& input, const int64 reference_time_ms_utc,
    111     const std::string& reference_timezone, const std::string& locales,
    112     ModeFlag mode, bool anchor_start_end,
    113     std::vector<DatetimeParseResultSpan>* results) const {
    114   return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
    115                reference_time_ms_utc, reference_timezone, locales, mode,
    116                anchor_start_end, results);
    117 }
    118 
    119 bool DatetimeParser::FindSpansUsingLocales(
    120     const std::vector<int>& locale_ids, const UnicodeText& input,
    121     const int64 reference_time_ms_utc, const std::string& reference_timezone,
    122     ModeFlag mode, bool anchor_start_end, const std::string& reference_locale,
    123     std::unordered_set<int>* executed_rules,
    124     std::vector<DatetimeParseResultSpan>* found_spans) const {
    125   for (const int locale_id : locale_ids) {
    126     auto rules_it = locale_to_rules_.find(locale_id);
    127     if (rules_it == locale_to_rules_.end()) {
    128       continue;
    129     }
    130 
    131     for (const int rule_id : rules_it->second) {
    132       // Skip rules that were already executed in previous locales.
    133       if (executed_rules->find(rule_id) != executed_rules->end()) {
    134         continue;
    135       }
    136 
    137       if (!(rules_[rule_id].pattern->enabled_modes() & mode)) {
    138         continue;
    139       }
    140 
    141       executed_rules->insert(rule_id);
    142 
    143       if (!ParseWithRule(rules_[rule_id], input, reference_time_ms_utc,
    144                          reference_timezone, reference_locale, locale_id,
    145                          anchor_start_end, found_spans)) {
    146         return false;
    147       }
    148     }
    149   }
    150   return true;
    151 }
    152 
    153 bool DatetimeParser::Parse(
    154     const UnicodeText& input, const int64 reference_time_ms_utc,
    155     const std::string& reference_timezone, const std::string& locales,
    156     ModeFlag mode, bool anchor_start_end,
    157     std::vector<DatetimeParseResultSpan>* results) const {
    158   std::vector<DatetimeParseResultSpan> found_spans;
    159   std::unordered_set<int> executed_rules;
    160   std::string reference_locale;
    161   const std::vector<int> requested_locales =
    162       ParseAndExpandLocales(locales, &reference_locale);
    163   if (!FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc,
    164                              reference_timezone, mode, anchor_start_end,
    165                              reference_locale, &executed_rules, &found_spans)) {
    166     return false;
    167   }
    168 
    169   std::vector<std::pair<DatetimeParseResultSpan, int>> indexed_found_spans;
    170   int counter = 0;
    171   for (const auto& found_span : found_spans) {
    172     indexed_found_spans.push_back({found_span, counter});
    173     counter++;
    174   }
    175 
    176   // Resolve conflicts by always picking the longer span and breaking ties by
    177   // selecting the earlier entry in the list for a given locale.
    178   std::sort(indexed_found_spans.begin(), indexed_found_spans.end(),
    179             [](const std::pair<DatetimeParseResultSpan, int>& a,
    180                const std::pair<DatetimeParseResultSpan, int>& b) {
    181               if ((a.first.span.second - a.first.span.first) !=
    182                   (b.first.span.second - b.first.span.first)) {
    183                 return (a.first.span.second - a.first.span.first) >
    184                        (b.first.span.second - b.first.span.first);
    185               } else {
    186                 return a.second < b.second;
    187               }
    188             });
    189 
    190   found_spans.clear();
    191   for (auto& span_index_pair : indexed_found_spans) {
    192     found_spans.push_back(span_index_pair.first);
    193   }
    194 
    195   std::set<int, std::function<bool(int, int)>> chosen_indices_set(
    196       [&found_spans](int a, int b) {
    197         return found_spans[a].span.first < found_spans[b].span.first;
    198       });
    199   for (int i = 0; i < found_spans.size(); ++i) {
    200     if (!DoesCandidateConflict(i, found_spans, chosen_indices_set)) {
    201       chosen_indices_set.insert(i);
    202       results->push_back(found_spans[i]);
    203     }
    204   }
    205 
    206   return true;
    207 }
    208 
    209 bool DatetimeParser::HandleParseMatch(
    210     const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
    211     int64 reference_time_ms_utc, const std::string& reference_timezone,
    212     const std::string& reference_locale, int locale_id,
    213     std::vector<DatetimeParseResultSpan>* result) const {
    214   int status = UniLib::RegexMatcher::kNoError;
    215   const int start = matcher.Start(&status);
    216   if (status != UniLib::RegexMatcher::kNoError) {
    217     return false;
    218   }
    219 
    220   const int end = matcher.End(&status);
    221   if (status != UniLib::RegexMatcher::kNoError) {
    222     return false;
    223   }
    224 
    225   DatetimeParseResultSpan parse_result;
    226   if (!ExtractDatetime(rule, matcher, reference_time_ms_utc, reference_timezone,
    227                        reference_locale, locale_id, &(parse_result.data),
    228                        &parse_result.span)) {
    229     return false;
    230   }
    231   if (!use_extractors_for_locating_) {
    232     parse_result.span = {start, end};
    233   }
    234   if (parse_result.span.first != kInvalidIndex &&
    235       parse_result.span.second != kInvalidIndex) {
    236     parse_result.target_classification_score =
    237         rule.pattern->target_classification_score();
    238     parse_result.priority_score = rule.pattern->priority_score();
    239     result->push_back(parse_result);
    240   }
    241   return true;
    242 }
    243 
    244 bool DatetimeParser::ParseWithRule(
    245     const CompiledRule& rule, const UnicodeText& input,
    246     const int64 reference_time_ms_utc, const std::string& reference_timezone,
    247     const std::string& reference_locale, const int locale_id,
    248     bool anchor_start_end, std::vector<DatetimeParseResultSpan>* result) const {
    249   std::unique_ptr<UniLib::RegexMatcher> matcher =
    250       rule.compiled_regex->Matcher(input);
    251   int status = UniLib::RegexMatcher::kNoError;
    252   if (anchor_start_end) {
    253     if (matcher->Matches(&status) && status == UniLib::RegexMatcher::kNoError) {
    254       if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
    255                             reference_timezone, reference_locale, locale_id,
    256                             result)) {
    257         return false;
    258       }
    259     }
    260   } else {
    261     while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
    262       if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
    263                             reference_timezone, reference_locale, locale_id,
    264                             result)) {
    265         return false;
    266       }
    267     }
    268   }
    269   return true;
    270 }
    271 
    272 std::vector<int> DatetimeParser::ParseAndExpandLocales(
    273     const std::string& locales, std::string* reference_locale) const {
    274   std::vector<StringPiece> split_locales = strings::Split(locales, ',');
    275   if (!split_locales.empty()) {
    276     *reference_locale = split_locales[0].ToString();
    277   } else {
    278     *reference_locale = "";
    279   }
    280 
    281   std::vector<int> result;
    282   for (const StringPiece& locale_str : split_locales) {
    283     auto locale_it = locale_string_to_id_.find(locale_str.ToString());
    284     if (locale_it != locale_string_to_id_.end()) {
    285       result.push_back(locale_it->second);
    286     }
    287 
    288     const Locale locale = Locale::FromBCP47(locale_str.ToString());
    289     if (!locale.IsValid()) {
    290       continue;
    291     }
    292 
    293     const std::string language = locale.Language();
    294     const std::string script = locale.Script();
    295     const std::string region = locale.Region();
    296 
    297     // First, try adding *-region locale.
    298     if (!region.empty()) {
    299       locale_it = locale_string_to_id_.find("*-" + region);
    300       if (locale_it != locale_string_to_id_.end()) {
    301         result.push_back(locale_it->second);
    302       }
    303     }
    304     // Second, try adding language-script-* locale.
    305     if (!script.empty()) {
    306       locale_it = locale_string_to_id_.find(language + "-" + script + "-*");
    307       if (locale_it != locale_string_to_id_.end()) {
    308         result.push_back(locale_it->second);
    309       }
    310     }
    311     // Third, try adding language-* locale.
    312     if (!language.empty()) {
    313       locale_it = locale_string_to_id_.find(language + "-*");
    314       if (locale_it != locale_string_to_id_.end()) {
    315         result.push_back(locale_it->second);
    316       }
    317     }
    318   }
    319 
    320   // Add the default locales if they haven't been added already.
    321   const std::unordered_set<int> result_set(result.begin(), result.end());
    322   for (const int default_locale_id : default_locale_ids_) {
    323     if (result_set.find(default_locale_id) == result_set.end()) {
    324       result.push_back(default_locale_id);
    325     }
    326   }
    327 
    328   return result;
    329 }
    330 
    331 namespace {
    332 
    333 DatetimeGranularity GetGranularity(const DateParseData& data) {
    334   DatetimeGranularity granularity = DatetimeGranularity::GRANULARITY_YEAR;
    335   if ((data.field_set_mask & DateParseData::YEAR_FIELD) ||
    336       (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
    337        (data.relation_type == DateParseData::RelationType::YEAR))) {
    338     granularity = DatetimeGranularity::GRANULARITY_YEAR;
    339   }
    340   if ((data.field_set_mask & DateParseData::MONTH_FIELD) ||
    341       (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
    342        (data.relation_type == DateParseData::RelationType::MONTH))) {
    343     granularity = DatetimeGranularity::GRANULARITY_MONTH;
    344   }
    345   if (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
    346       (data.relation_type == DateParseData::RelationType::WEEK)) {
    347     granularity = DatetimeGranularity::GRANULARITY_WEEK;
    348   }
    349   if (data.field_set_mask & DateParseData::DAY_FIELD ||
    350       (data.field_set_mask & DateParseData::RELATION_FIELD &&
    351        (data.relation == DateParseData::Relation::NOW ||
    352         data.relation == DateParseData::Relation::TOMORROW ||
    353         data.relation == DateParseData::Relation::YESTERDAY)) ||
    354       (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
    355        (data.relation_type == DateParseData::RelationType::MONDAY ||
    356         data.relation_type == DateParseData::RelationType::TUESDAY ||
    357         data.relation_type == DateParseData::RelationType::WEDNESDAY ||
    358         data.relation_type == DateParseData::RelationType::THURSDAY ||
    359         data.relation_type == DateParseData::RelationType::FRIDAY ||
    360         data.relation_type == DateParseData::RelationType::SATURDAY ||
    361         data.relation_type == DateParseData::RelationType::SUNDAY ||
    362         data.relation_type == DateParseData::RelationType::DAY))) {
    363     granularity = DatetimeGranularity::GRANULARITY_DAY;
    364   }
    365   if (data.field_set_mask & DateParseData::HOUR_FIELD) {
    366     granularity = DatetimeGranularity::GRANULARITY_HOUR;
    367   }
    368   if (data.field_set_mask & DateParseData::MINUTE_FIELD) {
    369     granularity = DatetimeGranularity::GRANULARITY_MINUTE;
    370   }
    371   if (data.field_set_mask & DateParseData::SECOND_FIELD) {
    372     granularity = DatetimeGranularity::GRANULARITY_SECOND;
    373   }
    374   return granularity;
    375 }
    376 
    377 }  // namespace
    378 
    379 bool DatetimeParser::ExtractDatetime(const CompiledRule& rule,
    380                                      const UniLib::RegexMatcher& matcher,
    381                                      const int64 reference_time_ms_utc,
    382                                      const std::string& reference_timezone,
    383                                      const std::string& reference_locale,
    384                                      int locale_id, DatetimeParseResult* result,
    385                                      CodepointSpan* result_span) const {
    386   DateParseData parse;
    387   DatetimeExtractor extractor(rule, matcher, locale_id, unilib_,
    388                               extractor_rules_,
    389                               type_and_locale_to_extractor_rule_);
    390   if (!extractor.Extract(&parse, result_span)) {
    391     return false;
    392   }
    393 
    394   result->granularity = GetGranularity(parse);
    395 
    396   if (!calendar_lib_.InterpretParseData(
    397           parse, reference_time_ms_utc, reference_timezone, reference_locale,
    398           result->granularity, &(result->time_ms_utc))) {
    399     return false;
    400   }
    401 
    402   return true;
    403 }
    404 
    405 }  // namespace libtextclassifier2
    406