1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_ 18 #define LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_ 19 20 #include <time.h> 21 #include <algorithm> 22 #include <cmath> 23 #include <functional> 24 #include <map> 25 #include <set> 26 #include <string> 27 #include <utility> 28 #include <vector> 29 30 #include "annotator/entity-data_generated.h" 31 #include "utils/base/integral_types.h" 32 #include "utils/base/logging.h" 33 #include "utils/flatbuffers.h" 34 #include "utils/variant.h" 35 36 namespace libtextclassifier3 { 37 38 constexpr int kInvalidIndex = -1; 39 40 // Index for a 0-based array of tokens. 41 using TokenIndex = int; 42 43 // Index for a 0-based array of codepoints. 44 using CodepointIndex = int; 45 46 // Marks a span in a sequence of codepoints. The first element is the index of 47 // the first codepoint of the span, and the second element is the index of the 48 // codepoint one past the end of the span. 49 // TODO(b/71982294): Make it a struct. 50 using CodepointSpan = std::pair<CodepointIndex, CodepointIndex>; 51 52 inline bool SpansOverlap(const CodepointSpan& a, const CodepointSpan& b) { 53 return a.first < b.second && b.first < a.second; 54 } 55 56 inline bool ValidNonEmptySpan(const CodepointSpan& span) { 57 return span.first < span.second && span.first >= 0 && span.second >= 0; 58 } 59 60 template <typename T> 61 bool DoesCandidateConflict( 62 const int considered_candidate, const std::vector<T>& candidates, 63 const std::set<int, std::function<bool(int, int)>>& chosen_indices_set) { 64 if (chosen_indices_set.empty()) { 65 return false; 66 } 67 68 auto conflicting_it = chosen_indices_set.lower_bound(considered_candidate); 69 // Check conflict on the right. 70 if (conflicting_it != chosen_indices_set.end() && 71 SpansOverlap(candidates[considered_candidate].span, 72 candidates[*conflicting_it].span)) { 73 return true; 74 } 75 76 // Check conflict on the left. 77 // If we can't go more left, there can't be a conflict: 78 if (conflicting_it == chosen_indices_set.begin()) { 79 return false; 80 } 81 // Otherwise move one span left and insert if it doesn't overlap with the 82 // candidate. 83 --conflicting_it; 84 if (!SpansOverlap(candidates[considered_candidate].span, 85 candidates[*conflicting_it].span)) { 86 return false; 87 } 88 89 return true; 90 } 91 92 // Marks a span in a sequence of tokens. The first element is the index of the 93 // first token in the span, and the second element is the index of the token one 94 // past the end of the span. 95 // TODO(b/71982294): Make it a struct. 96 using TokenSpan = std::pair<TokenIndex, TokenIndex>; 97 98 // Returns the size of the token span. Assumes that the span is valid. 99 inline int TokenSpanSize(const TokenSpan& token_span) { 100 return token_span.second - token_span.first; 101 } 102 103 // Returns a token span consisting of one token. 104 inline TokenSpan SingleTokenSpan(int token_index) { 105 return {token_index, token_index + 1}; 106 } 107 108 // Returns an intersection of two token spans. Assumes that both spans are valid 109 // and overlapping. 110 inline TokenSpan IntersectTokenSpans(const TokenSpan& token_span1, 111 const TokenSpan& token_span2) { 112 return {std::max(token_span1.first, token_span2.first), 113 std::min(token_span1.second, token_span2.second)}; 114 } 115 116 // Returns and expanded token span by adding a certain number of tokens on its 117 // left and on its right. 118 inline TokenSpan ExpandTokenSpan(const TokenSpan& token_span, 119 int num_tokens_left, int num_tokens_right) { 120 return {token_span.first - num_tokens_left, 121 token_span.second + num_tokens_right}; 122 } 123 124 // Token holds a token, its position in the original string and whether it was 125 // part of the input span. 126 struct Token { 127 std::string value; 128 CodepointIndex start; 129 CodepointIndex end; 130 131 // Whether the token is a padding token. 132 bool is_padding; 133 134 // Default constructor constructs the padding-token. 135 Token() 136 : value(""), start(kInvalidIndex), end(kInvalidIndex), is_padding(true) {} 137 138 Token(const std::string& arg_value, CodepointIndex arg_start, 139 CodepointIndex arg_end) 140 : value(arg_value), start(arg_start), end(arg_end), is_padding(false) {} 141 142 bool operator==(const Token& other) const { 143 return value == other.value && start == other.start && end == other.end && 144 is_padding == other.is_padding; 145 } 146 147 bool IsContainedInSpan(CodepointSpan span) const { 148 return start >= span.first && end <= span.second; 149 } 150 }; 151 152 // Pretty-printing function for Token. 153 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream, 154 const Token& token); 155 156 enum DatetimeGranularity { 157 GRANULARITY_UNKNOWN = -1, // GRANULARITY_UNKNOWN is used as a proxy for this 158 // structure being uninitialized. 159 GRANULARITY_YEAR = 0, 160 GRANULARITY_MONTH = 1, 161 GRANULARITY_WEEK = 2, 162 GRANULARITY_DAY = 3, 163 GRANULARITY_HOUR = 4, 164 GRANULARITY_MINUTE = 5, 165 GRANULARITY_SECOND = 6 166 }; 167 168 struct DatetimeParseResult { 169 // The absolute time in milliseconds since the epoch in UTC. 170 int64 time_ms_utc; 171 172 // The precision of the estimate then in to calculating the milliseconds 173 DatetimeGranularity granularity; 174 175 DatetimeParseResult() : time_ms_utc(0), granularity(GRANULARITY_UNKNOWN) {} 176 177 DatetimeParseResult(int64 arg_time_ms_utc, 178 DatetimeGranularity arg_granularity) 179 : time_ms_utc(arg_time_ms_utc), granularity(arg_granularity) {} 180 181 bool IsSet() const { return granularity != GRANULARITY_UNKNOWN; } 182 183 bool operator==(const DatetimeParseResult& other) const { 184 return granularity == other.granularity && time_ms_utc == other.time_ms_utc; 185 } 186 }; 187 188 const float kFloatCompareEpsilon = 1e-5; 189 190 struct DatetimeParseResultSpan { 191 CodepointSpan span; 192 std::vector<DatetimeParseResult> data; 193 float target_classification_score; 194 float priority_score; 195 196 bool operator==(const DatetimeParseResultSpan& other) const { 197 return span == other.span && data == other.data && 198 std::abs(target_classification_score - 199 other.target_classification_score) < kFloatCompareEpsilon && 200 std::abs(priority_score - other.priority_score) < 201 kFloatCompareEpsilon; 202 } 203 }; 204 205 // Pretty-printing function for DatetimeParseResultSpan. 206 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream, 207 const DatetimeParseResultSpan& value); 208 209 struct ClassificationResult { 210 std::string collection; 211 float score; 212 DatetimeParseResult datetime_parse_result; 213 std::string serialized_knowledge_result; 214 std::string contact_name, contact_given_name, contact_nickname, 215 contact_email_address, contact_phone_number, contact_id; 216 std::string app_name, app_package_name; 217 int64 numeric_value; 218 219 // Length of the parsed duration in milliseconds. 220 int64 duration_ms; 221 222 // Internal score used for conflict resolution. 223 float priority_score; 224 225 226 // Entity data information. 227 std::string serialized_entity_data; 228 const EntityData* entity_data() { 229 return LoadAndVerifyFlatbuffer<EntityData>(serialized_entity_data.data(), 230 serialized_entity_data.size()); 231 } 232 233 explicit ClassificationResult() : score(-1.0f), priority_score(-1.0) {} 234 235 ClassificationResult(const std::string& arg_collection, float arg_score) 236 : collection(arg_collection), 237 score(arg_score), 238 priority_score(arg_score) {} 239 240 ClassificationResult(const std::string& arg_collection, float arg_score, 241 float arg_priority_score) 242 : collection(arg_collection), 243 score(arg_score), 244 priority_score(arg_priority_score) {} 245 }; 246 247 // Pretty-printing function for ClassificationResult. 248 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream, 249 const ClassificationResult& result); 250 251 // Pretty-printing function for std::vector<ClassificationResult>. 252 logging::LoggingStringStream& operator<<( 253 logging::LoggingStringStream& stream, 254 const std::vector<ClassificationResult>& results); 255 256 // Represents a result of Annotate call. 257 struct AnnotatedSpan { 258 enum class Source { OTHER, KNOWLEDGE, DURATION, DATETIME }; 259 260 // Unicode codepoint indices in the input string. 261 CodepointSpan span = {kInvalidIndex, kInvalidIndex}; 262 263 // Classification result for the span. 264 std::vector<ClassificationResult> classification; 265 266 // The source of the annotation, used in conflict resolution. 267 Source source = Source::OTHER; 268 269 AnnotatedSpan() = default; 270 271 AnnotatedSpan(CodepointSpan arg_span, 272 std::vector<ClassificationResult> arg_classification) 273 : span(arg_span), classification(std::move(arg_classification)) {} 274 }; 275 276 // Pretty-printing function for AnnotatedSpan. 277 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream, 278 const AnnotatedSpan& span); 279 280 // StringPiece analogue for std::vector<T>. 281 template <class T> 282 class VectorSpan { 283 public: 284 VectorSpan() : begin_(), end_() {} 285 VectorSpan(const std::vector<T>& v) // NOLINT(runtime/explicit) 286 : begin_(v.begin()), end_(v.end()) {} 287 VectorSpan(typename std::vector<T>::const_iterator begin, 288 typename std::vector<T>::const_iterator end) 289 : begin_(begin), end_(end) {} 290 291 const T& operator[](typename std::vector<T>::size_type i) const { 292 return *(begin_ + i); 293 } 294 295 int size() const { return end_ - begin_; } 296 typename std::vector<T>::const_iterator begin() const { return begin_; } 297 typename std::vector<T>::const_iterator end() const { return end_; } 298 const float* data() const { return &(*begin_); } 299 300 private: 301 typename std::vector<T>::const_iterator begin_; 302 typename std::vector<T>::const_iterator end_; 303 }; 304 305 struct DateParseData { 306 enum class Relation { 307 UNSPECIFIED = 0, 308 NEXT = 1, 309 NEXT_OR_SAME = 2, 310 LAST = 3, 311 NOW = 4, 312 TOMORROW = 5, 313 YESTERDAY = 6, 314 PAST = 7, 315 FUTURE = 8 316 }; 317 318 enum class RelationType { 319 UNSPECIFIED = 0, 320 SUNDAY = 1, 321 MONDAY = 2, 322 TUESDAY = 3, 323 WEDNESDAY = 4, 324 THURSDAY = 5, 325 FRIDAY = 6, 326 SATURDAY = 7, 327 DAY = 8, 328 WEEK = 9, 329 MONTH = 10, 330 YEAR = 11, 331 HOUR = 12, 332 MINUTE = 13, 333 SECOND = 14, 334 }; 335 336 enum Fields { 337 YEAR_FIELD = 1 << 0, 338 MONTH_FIELD = 1 << 1, 339 DAY_FIELD = 1 << 2, 340 HOUR_FIELD = 1 << 3, 341 MINUTE_FIELD = 1 << 4, 342 SECOND_FIELD = 1 << 5, 343 AMPM_FIELD = 1 << 6, 344 ZONE_OFFSET_FIELD = 1 << 7, 345 DST_OFFSET_FIELD = 1 << 8, 346 RELATION_FIELD = 1 << 9, 347 RELATION_TYPE_FIELD = 1 << 10, 348 RELATION_DISTANCE_FIELD = 1 << 11 349 }; 350 351 enum class AMPM { AM = 0, PM = 1 }; 352 353 enum class TimeUnit { 354 DAYS = 1, 355 WEEKS = 2, 356 MONTHS = 3, 357 HOURS = 4, 358 MINUTES = 5, 359 SECONDS = 6, 360 YEARS = 7 361 }; 362 363 // Bit mask of fields which have been set on the struct 364 int field_set_mask = 0; 365 366 // Fields describing absolute date fields. 367 // Year of the date seen in the text match. 368 int year = 0; 369 // Month of the year starting with January = 1. 370 int month = 0; 371 // Day of the month starting with 1. 372 int day_of_month = 0; 373 // Hour of the day with a range of 0-23, 374 // values less than 12 need the AMPM field below or heuristics 375 // to definitively determine the time. 376 int hour = 0; 377 // Hour of the day with a range of 0-59. 378 int minute = 0; 379 // Hour of the day with a range of 0-59. 380 int second = 0; 381 // 0 == AM, 1 == PM 382 AMPM ampm = AMPM::AM; 383 // Number of hours offset from UTC this date time is in. 384 int zone_offset = 0; 385 // Number of hours offest for DST 386 int dst_offset = 0; 387 388 // The permutation from now that was made to find the date time. 389 Relation relation = Relation::UNSPECIFIED; 390 // The unit of measure of the change to the date time. 391 RelationType relation_type = RelationType::UNSPECIFIED; 392 // The number of units of change that were made. 393 int relation_distance = 0; 394 395 DateParseData() = default; 396 397 DateParseData(int field_set_mask, int year, int month, int day_of_month, 398 int hour, int minute, int second, AMPM ampm, int zone_offset, 399 int dst_offset, Relation relation, RelationType relation_type, 400 int relation_distance) { 401 this->field_set_mask = field_set_mask; 402 this->year = year; 403 this->month = month; 404 this->day_of_month = day_of_month; 405 this->hour = hour; 406 this->minute = minute; 407 this->second = second; 408 this->ampm = ampm; 409 this->zone_offset = zone_offset; 410 this->dst_offset = dst_offset; 411 this->relation = relation; 412 this->relation_type = relation_type; 413 this->relation_distance = relation_distance; 414 } 415 }; 416 417 // Pretty-printing function for DateParseData. 418 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream, 419 const DateParseData& data); 420 421 } // namespace libtextclassifier3 422 423 #endif // LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_ 424