1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LIBTEXTCLASSIFIER_TYPES_H_ 18 #define LIBTEXTCLASSIFIER_TYPES_H_ 19 20 #include <algorithm> 21 #include <cmath> 22 #include <functional> 23 #include <set> 24 #include <string> 25 #include <utility> 26 #include <vector> 27 #include "util/base/integral_types.h" 28 29 #include "util/base/logging.h" 30 31 namespace libtextclassifier2 { 32 33 constexpr int kInvalidIndex = -1; 34 35 // Index for a 0-based array of tokens. 36 using TokenIndex = int; 37 38 // Index for a 0-based array of codepoints. 39 using CodepointIndex = int; 40 41 // Marks a span in a sequence of codepoints. The first element is the index of 42 // the first codepoint of the span, and the second element is the index of the 43 // codepoint one past the end of the span. 44 // TODO(b/71982294): Make it a struct. 45 using CodepointSpan = std::pair<CodepointIndex, CodepointIndex>; 46 47 inline bool SpansOverlap(const CodepointSpan& a, const CodepointSpan& b) { 48 return a.first < b.second && b.first < a.second; 49 } 50 51 inline bool ValidNonEmptySpan(const CodepointSpan& span) { 52 return span.first < span.second && span.first >= 0 && span.second >= 0; 53 } 54 55 template <typename T> 56 bool DoesCandidateConflict( 57 const int considered_candidate, const std::vector<T>& candidates, 58 const std::set<int, std::function<bool(int, int)>>& chosen_indices_set) { 59 if (chosen_indices_set.empty()) { 60 return false; 61 } 62 63 auto conflicting_it = chosen_indices_set.lower_bound(considered_candidate); 64 // Check conflict on the right. 65 if (conflicting_it != chosen_indices_set.end() && 66 SpansOverlap(candidates[considered_candidate].span, 67 candidates[*conflicting_it].span)) { 68 return true; 69 } 70 71 // Check conflict on the left. 72 // If we can't go more left, there can't be a conflict: 73 if (conflicting_it == chosen_indices_set.begin()) { 74 return false; 75 } 76 // Otherwise move one span left and insert if it doesn't overlap with the 77 // candidate. 78 --conflicting_it; 79 if (!SpansOverlap(candidates[considered_candidate].span, 80 candidates[*conflicting_it].span)) { 81 return false; 82 } 83 84 return true; 85 } 86 87 // Marks a span in a sequence of tokens. The first element is the index of the 88 // first token in the span, and the second element is the index of the token one 89 // past the end of the span. 90 // TODO(b/71982294): Make it a struct. 91 using TokenSpan = std::pair<TokenIndex, TokenIndex>; 92 93 // Returns the size of the token span. Assumes that the span is valid. 94 inline int TokenSpanSize(const TokenSpan& token_span) { 95 return token_span.second - token_span.first; 96 } 97 98 // Returns a token span consisting of one token. 99 inline TokenSpan SingleTokenSpan(int token_index) { 100 return {token_index, token_index + 1}; 101 } 102 103 // Returns an intersection of two token spans. Assumes that both spans are valid 104 // and overlapping. 105 inline TokenSpan IntersectTokenSpans(const TokenSpan& token_span1, 106 const TokenSpan& token_span2) { 107 return {std::max(token_span1.first, token_span2.first), 108 std::min(token_span1.second, token_span2.second)}; 109 } 110 111 // Returns and expanded token span by adding a certain number of tokens on its 112 // left and on its right. 113 inline TokenSpan ExpandTokenSpan(const TokenSpan& token_span, 114 int num_tokens_left, int num_tokens_right) { 115 return {token_span.first - num_tokens_left, 116 token_span.second + num_tokens_right}; 117 } 118 119 // Token holds a token, its position in the original string and whether it was 120 // part of the input span. 121 struct Token { 122 std::string value; 123 CodepointIndex start; 124 CodepointIndex end; 125 126 // Whether the token is a padding token. 127 bool is_padding; 128 129 // Default constructor constructs the padding-token. 130 Token() 131 : value(""), start(kInvalidIndex), end(kInvalidIndex), is_padding(true) {} 132 133 Token(const std::string& arg_value, CodepointIndex arg_start, 134 CodepointIndex arg_end) 135 : value(arg_value), start(arg_start), end(arg_end), is_padding(false) {} 136 137 bool operator==(const Token& other) const { 138 return value == other.value && start == other.start && end == other.end && 139 is_padding == other.is_padding; 140 } 141 142 bool IsContainedInSpan(CodepointSpan span) const { 143 return start >= span.first && end <= span.second; 144 } 145 }; 146 147 // Pretty-printing function for Token. 148 inline logging::LoggingStringStream& operator<<( 149 logging::LoggingStringStream& stream, const Token& token) { 150 if (!token.is_padding) { 151 return stream << "Token(\"" << token.value << "\", " << token.start << ", " 152 << token.end << ")"; 153 } else { 154 return stream << "Token()"; 155 } 156 } 157 158 enum DatetimeGranularity { 159 GRANULARITY_UNKNOWN = -1, // GRANULARITY_UNKNOWN is used as a proxy for this 160 // structure being uninitialized. 161 GRANULARITY_YEAR = 0, 162 GRANULARITY_MONTH = 1, 163 GRANULARITY_WEEK = 2, 164 GRANULARITY_DAY = 3, 165 GRANULARITY_HOUR = 4, 166 GRANULARITY_MINUTE = 5, 167 GRANULARITY_SECOND = 6 168 }; 169 170 struct DatetimeParseResult { 171 // The absolute time in milliseconds since the epoch in UTC. This is derived 172 // from the reference time and the fields specified in the text - so it may 173 // be imperfect where the time was ambiguous. (e.g. "at 7:30" may be am or pm) 174 int64 time_ms_utc; 175 176 // The precision of the estimate then in to calculating the milliseconds 177 DatetimeGranularity granularity; 178 179 DatetimeParseResult() : time_ms_utc(0), granularity(GRANULARITY_UNKNOWN) {} 180 181 DatetimeParseResult(int64 arg_time_ms_utc, 182 DatetimeGranularity arg_granularity) 183 : time_ms_utc(arg_time_ms_utc), granularity(arg_granularity) {} 184 185 bool IsSet() const { return granularity != GRANULARITY_UNKNOWN; } 186 187 bool operator==(const DatetimeParseResult& other) const { 188 return granularity == other.granularity && time_ms_utc == other.time_ms_utc; 189 } 190 }; 191 192 const float kFloatCompareEpsilon = 1e-5; 193 194 struct DatetimeParseResultSpan { 195 CodepointSpan span; 196 DatetimeParseResult data; 197 float target_classification_score; 198 float priority_score; 199 200 bool operator==(const DatetimeParseResultSpan& other) const { 201 return span == other.span && data.granularity == other.data.granularity && 202 data.time_ms_utc == other.data.time_ms_utc && 203 std::abs(target_classification_score - 204 other.target_classification_score) < kFloatCompareEpsilon && 205 std::abs(priority_score - other.priority_score) < 206 kFloatCompareEpsilon; 207 } 208 }; 209 210 // Pretty-printing function for DatetimeParseResultSpan. 211 inline logging::LoggingStringStream& operator<<( 212 logging::LoggingStringStream& stream, 213 const DatetimeParseResultSpan& value) { 214 return stream << "DatetimeParseResultSpan({" << value.span.first << ", " 215 << value.span.second << "}, {/*time_ms_utc=*/ " 216 << value.data.time_ms_utc << ", /*granularity=*/ " 217 << value.data.granularity << "})"; 218 } 219 220 struct ClassificationResult { 221 std::string collection; 222 float score; 223 DatetimeParseResult datetime_parse_result; 224 225 // Internal score used for conflict resolution. 226 float priority_score; 227 228 explicit ClassificationResult() : score(-1.0f), priority_score(-1.0) {} 229 230 ClassificationResult(const std::string& arg_collection, float arg_score) 231 : collection(arg_collection), 232 score(arg_score), 233 priority_score(arg_score) {} 234 235 ClassificationResult(const std::string& arg_collection, float arg_score, 236 float arg_priority_score) 237 : collection(arg_collection), 238 score(arg_score), 239 priority_score(arg_priority_score) {} 240 }; 241 242 // Pretty-printing function for ClassificationResult. 243 inline logging::LoggingStringStream& operator<<( 244 logging::LoggingStringStream& stream, const ClassificationResult& result) { 245 return stream << "ClassificationResult(" << result.collection << ", " 246 << result.score << ")"; 247 } 248 249 // Pretty-printing function for std::vector<ClassificationResult>. 250 inline logging::LoggingStringStream& operator<<( 251 logging::LoggingStringStream& stream, 252 const std::vector<ClassificationResult>& results) { 253 stream = stream << "{\n"; 254 for (const ClassificationResult& result : results) { 255 stream = stream << " " << result << "\n"; 256 } 257 stream = stream << "}"; 258 return stream; 259 } 260 261 // Represents a result of Annotate call. 262 struct AnnotatedSpan { 263 // Unicode codepoint indices in the input string. 264 CodepointSpan span = {kInvalidIndex, kInvalidIndex}; 265 266 // Classification result for the span. 267 std::vector<ClassificationResult> classification; 268 }; 269 270 // Pretty-printing function for AnnotatedSpan. 271 inline logging::LoggingStringStream& operator<<( 272 logging::LoggingStringStream& stream, const AnnotatedSpan& span) { 273 std::string best_class; 274 float best_score = -1; 275 if (!span.classification.empty()) { 276 best_class = span.classification[0].collection; 277 best_score = span.classification[0].score; 278 } 279 return stream << "Span(" << span.span.first << ", " << span.span.second 280 << ", " << best_class << ", " << best_score << ")"; 281 } 282 283 // StringPiece analogue for std::vector<T>. 284 template <class T> 285 class VectorSpan { 286 public: 287 VectorSpan() : begin_(), end_() {} 288 VectorSpan(const std::vector<T>& v) // NOLINT(runtime/explicit) 289 : begin_(v.begin()), end_(v.end()) {} 290 VectorSpan(typename std::vector<T>::const_iterator begin, 291 typename std::vector<T>::const_iterator end) 292 : begin_(begin), end_(end) {} 293 294 const T& operator[](typename std::vector<T>::size_type i) const { 295 return *(begin_ + i); 296 } 297 298 int size() const { return end_ - begin_; } 299 typename std::vector<T>::const_iterator begin() const { return begin_; } 300 typename std::vector<T>::const_iterator end() const { return end_; } 301 const float* data() const { return &(*begin_); } 302 303 private: 304 typename std::vector<T>::const_iterator begin_; 305 typename std::vector<T>::const_iterator end_; 306 }; 307 308 struct DateParseData { 309 enum Relation { 310 NEXT = 1, 311 NEXT_OR_SAME = 2, 312 LAST = 3, 313 NOW = 4, 314 TOMORROW = 5, 315 YESTERDAY = 6, 316 PAST = 7, 317 FUTURE = 8 318 }; 319 320 enum RelationType { 321 MONDAY = 1, 322 TUESDAY = 2, 323 WEDNESDAY = 3, 324 THURSDAY = 4, 325 FRIDAY = 5, 326 SATURDAY = 6, 327 SUNDAY = 7, 328 DAY = 8, 329 WEEK = 9, 330 MONTH = 10, 331 YEAR = 11 332 }; 333 334 enum Fields { 335 YEAR_FIELD = 1 << 0, 336 MONTH_FIELD = 1 << 1, 337 DAY_FIELD = 1 << 2, 338 HOUR_FIELD = 1 << 3, 339 MINUTE_FIELD = 1 << 4, 340 SECOND_FIELD = 1 << 5, 341 AMPM_FIELD = 1 << 6, 342 ZONE_OFFSET_FIELD = 1 << 7, 343 DST_OFFSET_FIELD = 1 << 8, 344 RELATION_FIELD = 1 << 9, 345 RELATION_TYPE_FIELD = 1 << 10, 346 RELATION_DISTANCE_FIELD = 1 << 11 347 }; 348 349 enum AMPM { AM = 0, PM = 1 }; 350 351 enum TimeUnit { 352 DAYS = 1, 353 WEEKS = 2, 354 MONTHS = 3, 355 HOURS = 4, 356 MINUTES = 5, 357 SECONDS = 6, 358 YEARS = 7 359 }; 360 361 // Bit mask of fields which have been set on the struct 362 int field_set_mask; 363 364 // Fields describing absolute date fields. 365 // Year of the date seen in the text match. 366 int year; 367 // Month of the year starting with January = 1. 368 int month; 369 // Day of the month starting with 1. 370 int day_of_month; 371 // Hour of the day with a range of 0-23, 372 // values less than 12 need the AMPM field below or heuristics 373 // to definitively determine the time. 374 int hour; 375 // Hour of the day with a range of 0-59. 376 int minute; 377 // Hour of the day with a range of 0-59. 378 int second; 379 // 0 == AM, 1 == PM 380 int ampm; 381 // Number of hours offset from UTC this date time is in. 382 int zone_offset; 383 // Number of hours offest for DST 384 int dst_offset; 385 386 // The permutation from now that was made to find the date time. 387 Relation relation; 388 // The unit of measure of the change to the date time. 389 RelationType relation_type; 390 // The number of units of change that were made. 391 int relation_distance; 392 }; 393 394 } // namespace libtextclassifier2 395 396 #endif // LIBTEXTCLASSIFIER_TYPES_H_ 397