Home | History | Annotate | Download | only in libtextclassifier
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef LIBTEXTCLASSIFIER_TYPES_H_
     18 #define LIBTEXTCLASSIFIER_TYPES_H_
     19 
     20 #include <algorithm>
     21 #include <cmath>
     22 #include <functional>
     23 #include <set>
     24 #include <string>
     25 #include <utility>
     26 #include <vector>
     27 #include "util/base/integral_types.h"
     28 
     29 #include "util/base/logging.h"
     30 
     31 namespace libtextclassifier2 {
     32 
     33 constexpr int kInvalidIndex = -1;
     34 
     35 // Index for a 0-based array of tokens.
     36 using TokenIndex = int;
     37 
     38 // Index for a 0-based array of codepoints.
     39 using CodepointIndex = int;
     40 
     41 // Marks a span in a sequence of codepoints. The first element is the index of
     42 // the first codepoint of the span, and the second element is the index of the
     43 // codepoint one past the end of the span.
     44 // TODO(b/71982294): Make it a struct.
     45 using CodepointSpan = std::pair<CodepointIndex, CodepointIndex>;
     46 
     47 inline bool SpansOverlap(const CodepointSpan& a, const CodepointSpan& b) {
     48   return a.first < b.second && b.first < a.second;
     49 }
     50 
     51 inline bool ValidNonEmptySpan(const CodepointSpan& span) {
     52   return span.first < span.second && span.first >= 0 && span.second >= 0;
     53 }
     54 
     55 template <typename T>
     56 bool DoesCandidateConflict(
     57     const int considered_candidate, const std::vector<T>& candidates,
     58     const std::set<int, std::function<bool(int, int)>>& chosen_indices_set) {
     59   if (chosen_indices_set.empty()) {
     60     return false;
     61   }
     62 
     63   auto conflicting_it = chosen_indices_set.lower_bound(considered_candidate);
     64   // Check conflict on the right.
     65   if (conflicting_it != chosen_indices_set.end() &&
     66       SpansOverlap(candidates[considered_candidate].span,
     67                    candidates[*conflicting_it].span)) {
     68     return true;
     69   }
     70 
     71   // Check conflict on the left.
     72   // If we can't go more left, there can't be a conflict:
     73   if (conflicting_it == chosen_indices_set.begin()) {
     74     return false;
     75   }
     76   // Otherwise move one span left and insert if it doesn't overlap with the
     77   // candidate.
     78   --conflicting_it;
     79   if (!SpansOverlap(candidates[considered_candidate].span,
     80                     candidates[*conflicting_it].span)) {
     81     return false;
     82   }
     83 
     84   return true;
     85 }
     86 
     87 // Marks a span in a sequence of tokens. The first element is the index of the
     88 // first token in the span, and the second element is the index of the token one
     89 // past the end of the span.
     90 // TODO(b/71982294): Make it a struct.
     91 using TokenSpan = std::pair<TokenIndex, TokenIndex>;
     92 
     93 // Returns the size of the token span. Assumes that the span is valid.
     94 inline int TokenSpanSize(const TokenSpan& token_span) {
     95   return token_span.second - token_span.first;
     96 }
     97 
     98 // Returns a token span consisting of one token.
     99 inline TokenSpan SingleTokenSpan(int token_index) {
    100   return {token_index, token_index + 1};
    101 }
    102 
    103 // Returns an intersection of two token spans. Assumes that both spans are valid
    104 // and overlapping.
    105 inline TokenSpan IntersectTokenSpans(const TokenSpan& token_span1,
    106                                      const TokenSpan& token_span2) {
    107   return {std::max(token_span1.first, token_span2.first),
    108           std::min(token_span1.second, token_span2.second)};
    109 }
    110 
    111 // Returns and expanded token span by adding a certain number of tokens on its
    112 // left and on its right.
    113 inline TokenSpan ExpandTokenSpan(const TokenSpan& token_span,
    114                                  int num_tokens_left, int num_tokens_right) {
    115   return {token_span.first - num_tokens_left,
    116           token_span.second + num_tokens_right};
    117 }
    118 
    119 // Token holds a token, its position in the original string and whether it was
    120 // part of the input span.
    121 struct Token {
    122   std::string value;
    123   CodepointIndex start;
    124   CodepointIndex end;
    125 
    126   // Whether the token is a padding token.
    127   bool is_padding;
    128 
    129   // Default constructor constructs the padding-token.
    130   Token()
    131       : value(""), start(kInvalidIndex), end(kInvalidIndex), is_padding(true) {}
    132 
    133   Token(const std::string& arg_value, CodepointIndex arg_start,
    134         CodepointIndex arg_end)
    135       : value(arg_value), start(arg_start), end(arg_end), is_padding(false) {}
    136 
    137   bool operator==(const Token& other) const {
    138     return value == other.value && start == other.start && end == other.end &&
    139            is_padding == other.is_padding;
    140   }
    141 
    142   bool IsContainedInSpan(CodepointSpan span) const {
    143     return start >= span.first && end <= span.second;
    144   }
    145 };
    146 
    147 // Pretty-printing function for Token.
    148 inline logging::LoggingStringStream& operator<<(
    149     logging::LoggingStringStream& stream, const Token& token) {
    150   if (!token.is_padding) {
    151     return stream << "Token(\"" << token.value << "\", " << token.start << ", "
    152                   << token.end << ")";
    153   } else {
    154     return stream << "Token()";
    155   }
    156 }
    157 
    158 enum DatetimeGranularity {
    159   GRANULARITY_UNKNOWN = -1,  // GRANULARITY_UNKNOWN is used as a proxy for this
    160                              // structure being uninitialized.
    161   GRANULARITY_YEAR = 0,
    162   GRANULARITY_MONTH = 1,
    163   GRANULARITY_WEEK = 2,
    164   GRANULARITY_DAY = 3,
    165   GRANULARITY_HOUR = 4,
    166   GRANULARITY_MINUTE = 5,
    167   GRANULARITY_SECOND = 6
    168 };
    169 
    170 struct DatetimeParseResult {
    171   // The absolute time in milliseconds since the epoch in UTC. This is derived
    172   // from the reference time and the fields specified in the text - so it may
    173   // be imperfect where the time was ambiguous. (e.g. "at 7:30" may be am or pm)
    174   int64 time_ms_utc;
    175 
    176   // The precision of the estimate then in to calculating the milliseconds
    177   DatetimeGranularity granularity;
    178 
    179   DatetimeParseResult() : time_ms_utc(0), granularity(GRANULARITY_UNKNOWN) {}
    180 
    181   DatetimeParseResult(int64 arg_time_ms_utc,
    182                       DatetimeGranularity arg_granularity)
    183       : time_ms_utc(arg_time_ms_utc), granularity(arg_granularity) {}
    184 
    185   bool IsSet() const { return granularity != GRANULARITY_UNKNOWN; }
    186 
    187   bool operator==(const DatetimeParseResult& other) const {
    188     return granularity == other.granularity && time_ms_utc == other.time_ms_utc;
    189   }
    190 };
    191 
    192 const float kFloatCompareEpsilon = 1e-5;
    193 
    194 struct DatetimeParseResultSpan {
    195   CodepointSpan span;
    196   DatetimeParseResult data;
    197   float target_classification_score;
    198   float priority_score;
    199 
    200   bool operator==(const DatetimeParseResultSpan& other) const {
    201     return span == other.span && data.granularity == other.data.granularity &&
    202            data.time_ms_utc == other.data.time_ms_utc &&
    203            std::abs(target_classification_score -
    204                     other.target_classification_score) < kFloatCompareEpsilon &&
    205            std::abs(priority_score - other.priority_score) <
    206                kFloatCompareEpsilon;
    207   }
    208 };
    209 
    210 // Pretty-printing function for DatetimeParseResultSpan.
    211 inline logging::LoggingStringStream& operator<<(
    212     logging::LoggingStringStream& stream,
    213     const DatetimeParseResultSpan& value) {
    214   return stream << "DatetimeParseResultSpan({" << value.span.first << ", "
    215                 << value.span.second << "}, {/*time_ms_utc=*/ "
    216                 << value.data.time_ms_utc << ", /*granularity=*/ "
    217                 << value.data.granularity << "})";
    218 }
    219 
    220 struct ClassificationResult {
    221   std::string collection;
    222   float score;
    223   DatetimeParseResult datetime_parse_result;
    224 
    225   // Internal score used for conflict resolution.
    226   float priority_score;
    227 
    228   explicit ClassificationResult() : score(-1.0f), priority_score(-1.0) {}
    229 
    230   ClassificationResult(const std::string& arg_collection, float arg_score)
    231       : collection(arg_collection),
    232         score(arg_score),
    233         priority_score(arg_score) {}
    234 
    235   ClassificationResult(const std::string& arg_collection, float arg_score,
    236                        float arg_priority_score)
    237       : collection(arg_collection),
    238         score(arg_score),
    239         priority_score(arg_priority_score) {}
    240 };
    241 
    242 // Pretty-printing function for ClassificationResult.
    243 inline logging::LoggingStringStream& operator<<(
    244     logging::LoggingStringStream& stream, const ClassificationResult& result) {
    245   return stream << "ClassificationResult(" << result.collection << ", "
    246                 << result.score << ")";
    247 }
    248 
    249 // Pretty-printing function for std::vector<ClassificationResult>.
    250 inline logging::LoggingStringStream& operator<<(
    251     logging::LoggingStringStream& stream,
    252     const std::vector<ClassificationResult>& results) {
    253   stream = stream << "{\n";
    254   for (const ClassificationResult& result : results) {
    255     stream = stream << "    " << result << "\n";
    256   }
    257   stream = stream << "}";
    258   return stream;
    259 }
    260 
    261 // Represents a result of Annotate call.
    262 struct AnnotatedSpan {
    263   // Unicode codepoint indices in the input string.
    264   CodepointSpan span = {kInvalidIndex, kInvalidIndex};
    265 
    266   // Classification result for the span.
    267   std::vector<ClassificationResult> classification;
    268 };
    269 
    270 // Pretty-printing function for AnnotatedSpan.
    271 inline logging::LoggingStringStream& operator<<(
    272     logging::LoggingStringStream& stream, const AnnotatedSpan& span) {
    273   std::string best_class;
    274   float best_score = -1;
    275   if (!span.classification.empty()) {
    276     best_class = span.classification[0].collection;
    277     best_score = span.classification[0].score;
    278   }
    279   return stream << "Span(" << span.span.first << ", " << span.span.second
    280                 << ", " << best_class << ", " << best_score << ")";
    281 }
    282 
    283 // StringPiece analogue for std::vector<T>.
    284 template <class T>
    285 class VectorSpan {
    286  public:
    287   VectorSpan() : begin_(), end_() {}
    288   VectorSpan(const std::vector<T>& v)  // NOLINT(runtime/explicit)
    289       : begin_(v.begin()), end_(v.end()) {}
    290   VectorSpan(typename std::vector<T>::const_iterator begin,
    291              typename std::vector<T>::const_iterator end)
    292       : begin_(begin), end_(end) {}
    293 
    294   const T& operator[](typename std::vector<T>::size_type i) const {
    295     return *(begin_ + i);
    296   }
    297 
    298   int size() const { return end_ - begin_; }
    299   typename std::vector<T>::const_iterator begin() const { return begin_; }
    300   typename std::vector<T>::const_iterator end() const { return end_; }
    301   const float* data() const { return &(*begin_); }
    302 
    303  private:
    304   typename std::vector<T>::const_iterator begin_;
    305   typename std::vector<T>::const_iterator end_;
    306 };
    307 
    308 struct DateParseData {
    309   enum Relation {
    310     NEXT = 1,
    311     NEXT_OR_SAME = 2,
    312     LAST = 3,
    313     NOW = 4,
    314     TOMORROW = 5,
    315     YESTERDAY = 6,
    316     PAST = 7,
    317     FUTURE = 8
    318   };
    319 
    320   enum RelationType {
    321     MONDAY = 1,
    322     TUESDAY = 2,
    323     WEDNESDAY = 3,
    324     THURSDAY = 4,
    325     FRIDAY = 5,
    326     SATURDAY = 6,
    327     SUNDAY = 7,
    328     DAY = 8,
    329     WEEK = 9,
    330     MONTH = 10,
    331     YEAR = 11
    332   };
    333 
    334   enum Fields {
    335     YEAR_FIELD = 1 << 0,
    336     MONTH_FIELD = 1 << 1,
    337     DAY_FIELD = 1 << 2,
    338     HOUR_FIELD = 1 << 3,
    339     MINUTE_FIELD = 1 << 4,
    340     SECOND_FIELD = 1 << 5,
    341     AMPM_FIELD = 1 << 6,
    342     ZONE_OFFSET_FIELD = 1 << 7,
    343     DST_OFFSET_FIELD = 1 << 8,
    344     RELATION_FIELD = 1 << 9,
    345     RELATION_TYPE_FIELD = 1 << 10,
    346     RELATION_DISTANCE_FIELD = 1 << 11
    347   };
    348 
    349   enum AMPM { AM = 0, PM = 1 };
    350 
    351   enum TimeUnit {
    352     DAYS = 1,
    353     WEEKS = 2,
    354     MONTHS = 3,
    355     HOURS = 4,
    356     MINUTES = 5,
    357     SECONDS = 6,
    358     YEARS = 7
    359   };
    360 
    361   // Bit mask of fields which have been set on the struct
    362   int field_set_mask;
    363 
    364   // Fields describing absolute date fields.
    365   // Year of the date seen in the text match.
    366   int year;
    367   // Month of the year starting with January = 1.
    368   int month;
    369   // Day of the month starting with 1.
    370   int day_of_month;
    371   // Hour of the day with a range of 0-23,
    372   // values less than 12 need the AMPM field below or heuristics
    373   // to definitively determine the time.
    374   int hour;
    375   // Hour of the day with a range of 0-59.
    376   int minute;
    377   // Hour of the day with a range of 0-59.
    378   int second;
    379   // 0 == AM, 1 == PM
    380   int ampm;
    381   // Number of hours offset from UTC this date time is in.
    382   int zone_offset;
    383   // Number of hours offest for DST
    384   int dst_offset;
    385 
    386   // The permutation from now that was made to find the date time.
    387   Relation relation;
    388   // The unit of measure of the change to the date time.
    389   RelationType relation_type;
    390   // The number of units of change that were made.
    391   int relation_distance;
    392 };
    393 
    394 }  // namespace libtextclassifier2
    395 
    396 #endif  // LIBTEXTCLASSIFIER_TYPES_H_
    397