Home | History | Annotate | Download | only in annotator
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
     18 #define LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
     19 
     20 #include <time.h>
     21 #include <algorithm>
     22 #include <cmath>
     23 #include <functional>
     24 #include <map>
     25 #include <set>
     26 #include <string>
     27 #include <utility>
     28 #include <vector>
     29 
     30 #include "annotator/entity-data_generated.h"
     31 #include "utils/base/integral_types.h"
     32 #include "utils/base/logging.h"
     33 #include "utils/flatbuffers.h"
     34 #include "utils/variant.h"
     35 
     36 namespace libtextclassifier3 {
     37 
     38 constexpr int kInvalidIndex = -1;
     39 
     40 // Index for a 0-based array of tokens.
     41 using TokenIndex = int;
     42 
     43 // Index for a 0-based array of codepoints.
     44 using CodepointIndex = int;
     45 
     46 // Marks a span in a sequence of codepoints. The first element is the index of
     47 // the first codepoint of the span, and the second element is the index of the
     48 // codepoint one past the end of the span.
     49 // TODO(b/71982294): Make it a struct.
     50 using CodepointSpan = std::pair<CodepointIndex, CodepointIndex>;
     51 
     52 inline bool SpansOverlap(const CodepointSpan& a, const CodepointSpan& b) {
     53   return a.first < b.second && b.first < a.second;
     54 }
     55 
     56 inline bool ValidNonEmptySpan(const CodepointSpan& span) {
     57   return span.first < span.second && span.first >= 0 && span.second >= 0;
     58 }
     59 
     60 template <typename T>
     61 bool DoesCandidateConflict(
     62     const int considered_candidate, const std::vector<T>& candidates,
     63     const std::set<int, std::function<bool(int, int)>>& chosen_indices_set) {
     64   if (chosen_indices_set.empty()) {
     65     return false;
     66   }
     67 
     68   auto conflicting_it = chosen_indices_set.lower_bound(considered_candidate);
     69   // Check conflict on the right.
     70   if (conflicting_it != chosen_indices_set.end() &&
     71       SpansOverlap(candidates[considered_candidate].span,
     72                    candidates[*conflicting_it].span)) {
     73     return true;
     74   }
     75 
     76   // Check conflict on the left.
     77   // If we can't go more left, there can't be a conflict:
     78   if (conflicting_it == chosen_indices_set.begin()) {
     79     return false;
     80   }
     81   // Otherwise move one span left and insert if it doesn't overlap with the
     82   // candidate.
     83   --conflicting_it;
     84   if (!SpansOverlap(candidates[considered_candidate].span,
     85                     candidates[*conflicting_it].span)) {
     86     return false;
     87   }
     88 
     89   return true;
     90 }
     91 
     92 // Marks a span in a sequence of tokens. The first element is the index of the
     93 // first token in the span, and the second element is the index of the token one
     94 // past the end of the span.
     95 // TODO(b/71982294): Make it a struct.
     96 using TokenSpan = std::pair<TokenIndex, TokenIndex>;
     97 
     98 // Returns the size of the token span. Assumes that the span is valid.
     99 inline int TokenSpanSize(const TokenSpan& token_span) {
    100   return token_span.second - token_span.first;
    101 }
    102 
    103 // Returns a token span consisting of one token.
    104 inline TokenSpan SingleTokenSpan(int token_index) {
    105   return {token_index, token_index + 1};
    106 }
    107 
    108 // Returns an intersection of two token spans. Assumes that both spans are valid
    109 // and overlapping.
    110 inline TokenSpan IntersectTokenSpans(const TokenSpan& token_span1,
    111                                      const TokenSpan& token_span2) {
    112   return {std::max(token_span1.first, token_span2.first),
    113           std::min(token_span1.second, token_span2.second)};
    114 }
    115 
    116 // Returns and expanded token span by adding a certain number of tokens on its
    117 // left and on its right.
    118 inline TokenSpan ExpandTokenSpan(const TokenSpan& token_span,
    119                                  int num_tokens_left, int num_tokens_right) {
    120   return {token_span.first - num_tokens_left,
    121           token_span.second + num_tokens_right};
    122 }
    123 
    124 // Token holds a token, its position in the original string and whether it was
    125 // part of the input span.
    126 struct Token {
    127   std::string value;
    128   CodepointIndex start;
    129   CodepointIndex end;
    130 
    131   // Whether the token is a padding token.
    132   bool is_padding;
    133 
    134   // Default constructor constructs the padding-token.
    135   Token()
    136       : value(""), start(kInvalidIndex), end(kInvalidIndex), is_padding(true) {}
    137 
    138   Token(const std::string& arg_value, CodepointIndex arg_start,
    139         CodepointIndex arg_end)
    140       : value(arg_value), start(arg_start), end(arg_end), is_padding(false) {}
    141 
    142   bool operator==(const Token& other) const {
    143     return value == other.value && start == other.start && end == other.end &&
    144            is_padding == other.is_padding;
    145   }
    146 
    147   bool IsContainedInSpan(CodepointSpan span) const {
    148     return start >= span.first && end <= span.second;
    149   }
    150 };
    151 
    152 // Pretty-printing function for Token.
    153 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
    154                                          const Token& token);
    155 
    156 enum DatetimeGranularity {
    157   GRANULARITY_UNKNOWN = -1,  // GRANULARITY_UNKNOWN is used as a proxy for this
    158                              // structure being uninitialized.
    159   GRANULARITY_YEAR = 0,
    160   GRANULARITY_MONTH = 1,
    161   GRANULARITY_WEEK = 2,
    162   GRANULARITY_DAY = 3,
    163   GRANULARITY_HOUR = 4,
    164   GRANULARITY_MINUTE = 5,
    165   GRANULARITY_SECOND = 6
    166 };
    167 
    168 struct DatetimeParseResult {
    169   // The absolute time in milliseconds since the epoch in UTC.
    170   int64 time_ms_utc;
    171 
    172   // The precision of the estimate then in to calculating the milliseconds
    173   DatetimeGranularity granularity;
    174 
    175   DatetimeParseResult() : time_ms_utc(0), granularity(GRANULARITY_UNKNOWN) {}
    176 
    177   DatetimeParseResult(int64 arg_time_ms_utc,
    178                       DatetimeGranularity arg_granularity)
    179       : time_ms_utc(arg_time_ms_utc), granularity(arg_granularity) {}
    180 
    181   bool IsSet() const { return granularity != GRANULARITY_UNKNOWN; }
    182 
    183   bool operator==(const DatetimeParseResult& other) const {
    184     return granularity == other.granularity && time_ms_utc == other.time_ms_utc;
    185   }
    186 };
    187 
    188 const float kFloatCompareEpsilon = 1e-5;
    189 
    190 struct DatetimeParseResultSpan {
    191   CodepointSpan span;
    192   std::vector<DatetimeParseResult> data;
    193   float target_classification_score;
    194   float priority_score;
    195 
    196   bool operator==(const DatetimeParseResultSpan& other) const {
    197     return span == other.span && data == other.data &&
    198            std::abs(target_classification_score -
    199                     other.target_classification_score) < kFloatCompareEpsilon &&
    200            std::abs(priority_score - other.priority_score) <
    201                kFloatCompareEpsilon;
    202   }
    203 };
    204 
    205 // Pretty-printing function for DatetimeParseResultSpan.
    206 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
    207                                          const DatetimeParseResultSpan& value);
    208 
    209 struct ClassificationResult {
    210   std::string collection;
    211   float score;
    212   DatetimeParseResult datetime_parse_result;
    213   std::string serialized_knowledge_result;
    214   std::string contact_name, contact_given_name, contact_nickname,
    215       contact_email_address, contact_phone_number, contact_id;
    216   std::string app_name, app_package_name;
    217   int64 numeric_value;
    218 
    219   // Length of the parsed duration in milliseconds.
    220   int64 duration_ms;
    221 
    222   // Internal score used for conflict resolution.
    223   float priority_score;
    224 
    225 
    226   // Entity data information.
    227   std::string serialized_entity_data;
    228   const EntityData* entity_data() {
    229     return LoadAndVerifyFlatbuffer<EntityData>(serialized_entity_data.data(),
    230                                                serialized_entity_data.size());
    231   }
    232 
    233   explicit ClassificationResult() : score(-1.0f), priority_score(-1.0) {}
    234 
    235   ClassificationResult(const std::string& arg_collection, float arg_score)
    236       : collection(arg_collection),
    237         score(arg_score),
    238         priority_score(arg_score) {}
    239 
    240   ClassificationResult(const std::string& arg_collection, float arg_score,
    241                        float arg_priority_score)
    242       : collection(arg_collection),
    243         score(arg_score),
    244         priority_score(arg_priority_score) {}
    245 };
    246 
    247 // Pretty-printing function for ClassificationResult.
    248 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
    249                                          const ClassificationResult& result);
    250 
    251 // Pretty-printing function for std::vector<ClassificationResult>.
    252 logging::LoggingStringStream& operator<<(
    253     logging::LoggingStringStream& stream,
    254     const std::vector<ClassificationResult>& results);
    255 
    256 // Represents a result of Annotate call.
    257 struct AnnotatedSpan {
    258   enum class Source { OTHER, KNOWLEDGE, DURATION, DATETIME };
    259 
    260   // Unicode codepoint indices in the input string.
    261   CodepointSpan span = {kInvalidIndex, kInvalidIndex};
    262 
    263   // Classification result for the span.
    264   std::vector<ClassificationResult> classification;
    265 
    266   // The source of the annotation, used in conflict resolution.
    267   Source source = Source::OTHER;
    268 
    269   AnnotatedSpan() = default;
    270 
    271   AnnotatedSpan(CodepointSpan arg_span,
    272                 std::vector<ClassificationResult> arg_classification)
    273       : span(arg_span), classification(std::move(arg_classification)) {}
    274 };
    275 
    276 // Pretty-printing function for AnnotatedSpan.
    277 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
    278                                          const AnnotatedSpan& span);
    279 
    280 // StringPiece analogue for std::vector<T>.
    281 template <class T>
    282 class VectorSpan {
    283  public:
    284   VectorSpan() : begin_(), end_() {}
    285   VectorSpan(const std::vector<T>& v)  // NOLINT(runtime/explicit)
    286       : begin_(v.begin()), end_(v.end()) {}
    287   VectorSpan(typename std::vector<T>::const_iterator begin,
    288              typename std::vector<T>::const_iterator end)
    289       : begin_(begin), end_(end) {}
    290 
    291   const T& operator[](typename std::vector<T>::size_type i) const {
    292     return *(begin_ + i);
    293   }
    294 
    295   int size() const { return end_ - begin_; }
    296   typename std::vector<T>::const_iterator begin() const { return begin_; }
    297   typename std::vector<T>::const_iterator end() const { return end_; }
    298   const float* data() const { return &(*begin_); }
    299 
    300  private:
    301   typename std::vector<T>::const_iterator begin_;
    302   typename std::vector<T>::const_iterator end_;
    303 };
    304 
    305 struct DateParseData {
    306   enum class Relation {
    307     UNSPECIFIED = 0,
    308     NEXT = 1,
    309     NEXT_OR_SAME = 2,
    310     LAST = 3,
    311     NOW = 4,
    312     TOMORROW = 5,
    313     YESTERDAY = 6,
    314     PAST = 7,
    315     FUTURE = 8
    316   };
    317 
    318   enum class RelationType {
    319     UNSPECIFIED = 0,
    320     SUNDAY = 1,
    321     MONDAY = 2,
    322     TUESDAY = 3,
    323     WEDNESDAY = 4,
    324     THURSDAY = 5,
    325     FRIDAY = 6,
    326     SATURDAY = 7,
    327     DAY = 8,
    328     WEEK = 9,
    329     MONTH = 10,
    330     YEAR = 11,
    331     HOUR = 12,
    332     MINUTE = 13,
    333     SECOND = 14,
    334   };
    335 
    336   enum Fields {
    337     YEAR_FIELD = 1 << 0,
    338     MONTH_FIELD = 1 << 1,
    339     DAY_FIELD = 1 << 2,
    340     HOUR_FIELD = 1 << 3,
    341     MINUTE_FIELD = 1 << 4,
    342     SECOND_FIELD = 1 << 5,
    343     AMPM_FIELD = 1 << 6,
    344     ZONE_OFFSET_FIELD = 1 << 7,
    345     DST_OFFSET_FIELD = 1 << 8,
    346     RELATION_FIELD = 1 << 9,
    347     RELATION_TYPE_FIELD = 1 << 10,
    348     RELATION_DISTANCE_FIELD = 1 << 11
    349   };
    350 
    351   enum class AMPM { AM = 0, PM = 1 };
    352 
    353   enum class TimeUnit {
    354     DAYS = 1,
    355     WEEKS = 2,
    356     MONTHS = 3,
    357     HOURS = 4,
    358     MINUTES = 5,
    359     SECONDS = 6,
    360     YEARS = 7
    361   };
    362 
    363   // Bit mask of fields which have been set on the struct
    364   int field_set_mask = 0;
    365 
    366   // Fields describing absolute date fields.
    367   // Year of the date seen in the text match.
    368   int year = 0;
    369   // Month of the year starting with January = 1.
    370   int month = 0;
    371   // Day of the month starting with 1.
    372   int day_of_month = 0;
    373   // Hour of the day with a range of 0-23,
    374   // values less than 12 need the AMPM field below or heuristics
    375   // to definitively determine the time.
    376   int hour = 0;
    377   // Hour of the day with a range of 0-59.
    378   int minute = 0;
    379   // Hour of the day with a range of 0-59.
    380   int second = 0;
    381   // 0 == AM, 1 == PM
    382   AMPM ampm = AMPM::AM;
    383   // Number of hours offset from UTC this date time is in.
    384   int zone_offset = 0;
    385   // Number of hours offest for DST
    386   int dst_offset = 0;
    387 
    388   // The permutation from now that was made to find the date time.
    389   Relation relation = Relation::UNSPECIFIED;
    390   // The unit of measure of the change to the date time.
    391   RelationType relation_type = RelationType::UNSPECIFIED;
    392   // The number of units of change that were made.
    393   int relation_distance = 0;
    394 
    395   DateParseData() = default;
    396 
    397   DateParseData(int field_set_mask, int year, int month, int day_of_month,
    398                 int hour, int minute, int second, AMPM ampm, int zone_offset,
    399                 int dst_offset, Relation relation, RelationType relation_type,
    400                 int relation_distance) {
    401     this->field_set_mask = field_set_mask;
    402     this->year = year;
    403     this->month = month;
    404     this->day_of_month = day_of_month;
    405     this->hour = hour;
    406     this->minute = minute;
    407     this->second = second;
    408     this->ampm = ampm;
    409     this->zone_offset = zone_offset;
    410     this->dst_offset = dst_offset;
    411     this->relation = relation;
    412     this->relation_type = relation_type;
    413     this->relation_distance = relation_distance;
    414   }
    415 };
    416 
    417 // Pretty-printing function for DateParseData.
    418 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
    419                                          const DateParseData& data);
    420 
    421 }  // namespace libtextclassifier3
    422 
    423 #endif  // LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
    424