1 // float-weight.h 2 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Copyright 2005-2010 Google, Inc. 16 // Author: riley (at) google.com (Michael Riley) 17 // 18 // \file 19 // Float weight set and associated semiring operation definitions. 20 // 21 22 #ifndef FST_LIB_FLOAT_WEIGHT_H__ 23 #define FST_LIB_FLOAT_WEIGHT_H__ 24 25 #include <limits> 26 #include <climits> 27 #include <sstream> 28 #include <string> 29 30 #include <fst/util.h> 31 #include <fst/weight.h> 32 33 34 namespace fst { 35 36 // numeric limits class 37 template <class T> 38 class FloatLimits { 39 public: 40 static const T kPosInfinity; 41 static const T kNegInfinity; 42 static const T kNumberBad; 43 }; 44 45 template <class T> 46 const T FloatLimits<T>::kPosInfinity = numeric_limits<T>::infinity(); 47 48 template <class T> 49 const T FloatLimits<T>::kNegInfinity = -FloatLimits<T>::kPosInfinity; 50 51 template <class T> 52 const T FloatLimits<T>::kNumberBad = numeric_limits<T>::quiet_NaN(); 53 54 // weight class to be templated on floating-points types 55 template <class T = float> 56 class FloatWeightTpl { 57 public: 58 FloatWeightTpl() {} 59 60 FloatWeightTpl(T f) : value_(f) {} 61 62 FloatWeightTpl(const FloatWeightTpl<T> &w) : value_(w.value_) {} 63 64 FloatWeightTpl<T> &operator=(const FloatWeightTpl<T> &w) { 65 value_ = w.value_; 66 return *this; 67 } 68 69 istream &Read(istream &strm) { 70 return ReadType(strm, &value_); 71 } 72 73 ostream &Write(ostream &strm) const { 74 return WriteType(strm, value_); 75 } 76 77 size_t Hash() const { 78 union { 79 T f; 80 size_t s; 81 } u; 82 u.s = 0; 83 u.f = value_; 84 return u.s; 85 } 86 87 const T &Value() const { return value_; } 88 89 protected: 90 void SetValue(const T &f) { value_ = f; } 91 92 inline static string GetPrecisionString() { 93 int64 size = sizeof(T); 94 if (size == sizeof(float)) return ""; 95 size *= CHAR_BIT; 96 97 string result; 98 Int64ToStr(size, &result); 99 return result; 100 } 101 102 private: 103 T value_; 104 }; 105 106 // Single-precision float weight 107 typedef FloatWeightTpl<float> FloatWeight; 108 109 template <class T> 110 inline bool operator==(const FloatWeightTpl<T> &w1, 111 const FloatWeightTpl<T> &w2) { 112 // Volatile qualifier thwarts over-aggressive compiler optimizations 113 // that lead to problems esp. with NaturalLess(). 114 volatile T v1 = w1.Value(); 115 volatile T v2 = w2.Value(); 116 return v1 == v2; 117 } 118 119 inline bool operator==(const FloatWeightTpl<double> &w1, 120 const FloatWeightTpl<double> &w2) { 121 return operator==<double>(w1, w2); 122 } 123 124 inline bool operator==(const FloatWeightTpl<float> &w1, 125 const FloatWeightTpl<float> &w2) { 126 return operator==<float>(w1, w2); 127 } 128 129 template <class T> 130 inline bool operator!=(const FloatWeightTpl<T> &w1, 131 const FloatWeightTpl<T> &w2) { 132 return !(w1 == w2); 133 } 134 135 inline bool operator!=(const FloatWeightTpl<double> &w1, 136 const FloatWeightTpl<double> &w2) { 137 return operator!=<double>(w1, w2); 138 } 139 140 inline bool operator!=(const FloatWeightTpl<float> &w1, 141 const FloatWeightTpl<float> &w2) { 142 return operator!=<float>(w1, w2); 143 } 144 145 template <class T> 146 inline bool ApproxEqual(const FloatWeightTpl<T> &w1, 147 const FloatWeightTpl<T> &w2, 148 float delta = kDelta) { 149 return w1.Value() <= w2.Value() + delta && w2.Value() <= w1.Value() + delta; 150 } 151 152 template <class T> 153 inline ostream &operator<<(ostream &strm, const FloatWeightTpl<T> &w) { 154 if (w.Value() == FloatLimits<T>::kPosInfinity) 155 return strm << "Infinity"; 156 else if (w.Value() == FloatLimits<T>::kNegInfinity) 157 return strm << "-Infinity"; 158 else if (w.Value() != w.Value()) // Fails for NaN 159 return strm << "BadNumber"; 160 else 161 return strm << w.Value(); 162 } 163 164 template <class T> 165 inline istream &operator>>(istream &strm, FloatWeightTpl<T> &w) { 166 string s; 167 strm >> s; 168 if (s == "Infinity") { 169 w = FloatWeightTpl<T>(FloatLimits<T>::kPosInfinity); 170 } else if (s == "-Infinity") { 171 w = FloatWeightTpl<T>(FloatLimits<T>::kNegInfinity); 172 } else { 173 char *p; 174 T f = strtod(s.c_str(), &p); 175 if (p < s.c_str() + s.size()) 176 strm.clear(std::ios::badbit); 177 else 178 w = FloatWeightTpl<T>(f); 179 } 180 return strm; 181 } 182 183 184 // Tropical semiring: (min, +, inf, 0) 185 template <class T> 186 class TropicalWeightTpl : public FloatWeightTpl<T> { 187 public: 188 using FloatWeightTpl<T>::Value; 189 190 typedef TropicalWeightTpl<T> ReverseWeight; 191 192 TropicalWeightTpl() : FloatWeightTpl<T>() {} 193 194 TropicalWeightTpl(T f) : FloatWeightTpl<T>(f) {} 195 196 TropicalWeightTpl(const TropicalWeightTpl<T> &w) : FloatWeightTpl<T>(w) {} 197 198 static const TropicalWeightTpl<T> Zero() { 199 return TropicalWeightTpl<T>(FloatLimits<T>::kPosInfinity); } 200 201 static const TropicalWeightTpl<T> One() { 202 return TropicalWeightTpl<T>(0.0F); } 203 204 static const TropicalWeightTpl<T> NoWeight() { 205 return TropicalWeightTpl<T>(FloatLimits<T>::kNumberBad); } 206 207 static const string &Type() { 208 static const string type = "tropical" + 209 FloatWeightTpl<T>::GetPrecisionString(); 210 return type; 211 } 212 213 bool Member() const { 214 // First part fails for IEEE NaN 215 return Value() == Value() && Value() != FloatLimits<T>::kNegInfinity; 216 } 217 218 TropicalWeightTpl<T> Quantize(float delta = kDelta) const { 219 if (Value() == FloatLimits<T>::kNegInfinity || 220 Value() == FloatLimits<T>::kPosInfinity || 221 Value() != Value()) 222 return *this; 223 else 224 return TropicalWeightTpl<T>(floor(Value()/delta + 0.5F) * delta); 225 } 226 227 TropicalWeightTpl<T> Reverse() const { return *this; } 228 229 static uint64 Properties() { 230 return kLeftSemiring | kRightSemiring | kCommutative | 231 kPath | kIdempotent; 232 } 233 }; 234 235 // Single precision tropical weight 236 typedef TropicalWeightTpl<float> TropicalWeight; 237 238 template <class T> 239 inline TropicalWeightTpl<T> Plus(const TropicalWeightTpl<T> &w1, 240 const TropicalWeightTpl<T> &w2) { 241 if (!w1.Member() || !w2.Member()) 242 return TropicalWeightTpl<T>::NoWeight(); 243 return w1.Value() < w2.Value() ? w1 : w2; 244 } 245 246 inline TropicalWeightTpl<float> Plus(const TropicalWeightTpl<float> &w1, 247 const TropicalWeightTpl<float> &w2) { 248 return Plus<float>(w1, w2); 249 } 250 251 inline TropicalWeightTpl<double> Plus(const TropicalWeightTpl<double> &w1, 252 const TropicalWeightTpl<double> &w2) { 253 return Plus<double>(w1, w2); 254 } 255 256 template <class T> 257 inline TropicalWeightTpl<T> Times(const TropicalWeightTpl<T> &w1, 258 const TropicalWeightTpl<T> &w2) { 259 if (!w1.Member() || !w2.Member()) 260 return TropicalWeightTpl<T>::NoWeight(); 261 T f1 = w1.Value(), f2 = w2.Value(); 262 if (f1 == FloatLimits<T>::kPosInfinity) 263 return w1; 264 else if (f2 == FloatLimits<T>::kPosInfinity) 265 return w2; 266 else 267 return TropicalWeightTpl<T>(f1 + f2); 268 } 269 270 inline TropicalWeightTpl<float> Times(const TropicalWeightTpl<float> &w1, 271 const TropicalWeightTpl<float> &w2) { 272 return Times<float>(w1, w2); 273 } 274 275 inline TropicalWeightTpl<double> Times(const TropicalWeightTpl<double> &w1, 276 const TropicalWeightTpl<double> &w2) { 277 return Times<double>(w1, w2); 278 } 279 280 template <class T> 281 inline TropicalWeightTpl<T> Divide(const TropicalWeightTpl<T> &w1, 282 const TropicalWeightTpl<T> &w2, 283 DivideType typ = DIVIDE_ANY) { 284 if (!w1.Member() || !w2.Member()) 285 return TropicalWeightTpl<T>::NoWeight(); 286 T f1 = w1.Value(), f2 = w2.Value(); 287 if (f2 == FloatLimits<T>::kPosInfinity) 288 return FloatLimits<T>::kNumberBad; 289 else if (f1 == FloatLimits<T>::kPosInfinity) 290 return FloatLimits<T>::kPosInfinity; 291 else 292 return TropicalWeightTpl<T>(f1 - f2); 293 } 294 295 inline TropicalWeightTpl<float> Divide(const TropicalWeightTpl<float> &w1, 296 const TropicalWeightTpl<float> &w2, 297 DivideType typ = DIVIDE_ANY) { 298 return Divide<float>(w1, w2, typ); 299 } 300 301 inline TropicalWeightTpl<double> Divide(const TropicalWeightTpl<double> &w1, 302 const TropicalWeightTpl<double> &w2, 303 DivideType typ = DIVIDE_ANY) { 304 return Divide<double>(w1, w2, typ); 305 } 306 307 308 // Log semiring: (log(e^-x + e^y), +, inf, 0) 309 template <class T> 310 class LogWeightTpl : public FloatWeightTpl<T> { 311 public: 312 using FloatWeightTpl<T>::Value; 313 314 typedef LogWeightTpl ReverseWeight; 315 316 LogWeightTpl() : FloatWeightTpl<T>() {} 317 318 LogWeightTpl(T f) : FloatWeightTpl<T>(f) {} 319 320 LogWeightTpl(const LogWeightTpl<T> &w) : FloatWeightTpl<T>(w) {} 321 322 static const LogWeightTpl<T> Zero() { 323 return LogWeightTpl<T>(FloatLimits<T>::kPosInfinity); 324 } 325 326 static const LogWeightTpl<T> One() { 327 return LogWeightTpl<T>(0.0F); 328 } 329 330 static const LogWeightTpl<T> NoWeight() { 331 return LogWeightTpl<T>(FloatLimits<T>::kNumberBad); } 332 333 static const string &Type() { 334 static const string type = "log" + FloatWeightTpl<T>::GetPrecisionString(); 335 return type; 336 } 337 338 bool Member() const { 339 // First part fails for IEEE NaN 340 return Value() == Value() && Value() != FloatLimits<T>::kNegInfinity; 341 } 342 343 LogWeightTpl<T> Quantize(float delta = kDelta) const { 344 if (Value() == FloatLimits<T>::kNegInfinity || 345 Value() == FloatLimits<T>::kPosInfinity || 346 Value() != Value()) 347 return *this; 348 else 349 return LogWeightTpl<T>(floor(Value()/delta + 0.5F) * delta); 350 } 351 352 LogWeightTpl<T> Reverse() const { return *this; } 353 354 static uint64 Properties() { 355 return kLeftSemiring | kRightSemiring | kCommutative; 356 } 357 }; 358 359 // Single-precision log weight 360 typedef LogWeightTpl<float> LogWeight; 361 // Double-precision log weight 362 typedef LogWeightTpl<double> Log64Weight; 363 364 template <class T> 365 inline T LogExp(T x) { return log(1.0F + exp(-x)); } 366 367 template <class T> 368 inline LogWeightTpl<T> Plus(const LogWeightTpl<T> &w1, 369 const LogWeightTpl<T> &w2) { 370 T f1 = w1.Value(), f2 = w2.Value(); 371 if (f1 == FloatLimits<T>::kPosInfinity) 372 return w2; 373 else if (f2 == FloatLimits<T>::kPosInfinity) 374 return w1; 375 else if (f1 > f2) 376 return LogWeightTpl<T>(f2 - LogExp(f1 - f2)); 377 else 378 return LogWeightTpl<T>(f1 - LogExp(f2 - f1)); 379 } 380 381 inline LogWeightTpl<float> Plus(const LogWeightTpl<float> &w1, 382 const LogWeightTpl<float> &w2) { 383 return Plus<float>(w1, w2); 384 } 385 386 inline LogWeightTpl<double> Plus(const LogWeightTpl<double> &w1, 387 const LogWeightTpl<double> &w2) { 388 return Plus<double>(w1, w2); 389 } 390 391 template <class T> 392 inline LogWeightTpl<T> Times(const LogWeightTpl<T> &w1, 393 const LogWeightTpl<T> &w2) { 394 if (!w1.Member() || !w2.Member()) 395 return LogWeightTpl<T>::NoWeight(); 396 T f1 = w1.Value(), f2 = w2.Value(); 397 if (f1 == FloatLimits<T>::kPosInfinity) 398 return w1; 399 else if (f2 == FloatLimits<T>::kPosInfinity) 400 return w2; 401 else 402 return LogWeightTpl<T>(f1 + f2); 403 } 404 405 inline LogWeightTpl<float> Times(const LogWeightTpl<float> &w1, 406 const LogWeightTpl<float> &w2) { 407 return Times<float>(w1, w2); 408 } 409 410 inline LogWeightTpl<double> Times(const LogWeightTpl<double> &w1, 411 const LogWeightTpl<double> &w2) { 412 return Times<double>(w1, w2); 413 } 414 415 template <class T> 416 inline LogWeightTpl<T> Divide(const LogWeightTpl<T> &w1, 417 const LogWeightTpl<T> &w2, 418 DivideType typ = DIVIDE_ANY) { 419 if (!w1.Member() || !w2.Member()) 420 return LogWeightTpl<T>::NoWeight(); 421 T f1 = w1.Value(), f2 = w2.Value(); 422 if (f2 == FloatLimits<T>::kPosInfinity) 423 return FloatLimits<T>::kNumberBad; 424 else if (f1 == FloatLimits<T>::kPosInfinity) 425 return FloatLimits<T>::kPosInfinity; 426 else 427 return LogWeightTpl<T>(f1 - f2); 428 } 429 430 inline LogWeightTpl<float> Divide(const LogWeightTpl<float> &w1, 431 const LogWeightTpl<float> &w2, 432 DivideType typ = DIVIDE_ANY) { 433 return Divide<float>(w1, w2, typ); 434 } 435 436 inline LogWeightTpl<double> Divide(const LogWeightTpl<double> &w1, 437 const LogWeightTpl<double> &w2, 438 DivideType typ = DIVIDE_ANY) { 439 return Divide<double>(w1, w2, typ); 440 } 441 442 // MinMax semiring: (min, max, inf, -inf) 443 template <class T> 444 class MinMaxWeightTpl : public FloatWeightTpl<T> { 445 public: 446 using FloatWeightTpl<T>::Value; 447 448 typedef MinMaxWeightTpl<T> ReverseWeight; 449 450 MinMaxWeightTpl() : FloatWeightTpl<T>() {} 451 452 MinMaxWeightTpl(T f) : FloatWeightTpl<T>(f) {} 453 454 MinMaxWeightTpl(const MinMaxWeightTpl<T> &w) : FloatWeightTpl<T>(w) {} 455 456 static const MinMaxWeightTpl<T> Zero() { 457 return MinMaxWeightTpl<T>(FloatLimits<T>::kPosInfinity); 458 } 459 460 static const MinMaxWeightTpl<T> One() { 461 return MinMaxWeightTpl<T>(FloatLimits<T>::kNegInfinity); 462 } 463 464 static const MinMaxWeightTpl<T> NoWeight() { 465 return MinMaxWeightTpl<T>(FloatLimits<T>::kNumberBad); } 466 467 static const string &Type() { 468 static const string type = "minmax" + 469 FloatWeightTpl<T>::GetPrecisionString(); 470 return type; 471 } 472 473 bool Member() const { 474 // Fails for IEEE NaN 475 return Value() == Value(); 476 } 477 478 MinMaxWeightTpl<T> Quantize(float delta = kDelta) const { 479 // If one of infinities, or a NaN 480 if (Value() == FloatLimits<T>::kNegInfinity || 481 Value() == FloatLimits<T>::kPosInfinity || 482 Value() != Value()) 483 return *this; 484 else 485 return MinMaxWeightTpl<T>(floor(Value()/delta + 0.5F) * delta); 486 } 487 488 MinMaxWeightTpl<T> Reverse() const { return *this; } 489 490 static uint64 Properties() { 491 return kLeftSemiring | kRightSemiring | kCommutative | kIdempotent | kPath; 492 } 493 }; 494 495 // Single-precision min-max weight 496 typedef MinMaxWeightTpl<float> MinMaxWeight; 497 498 // Min 499 template <class T> 500 inline MinMaxWeightTpl<T> Plus( 501 const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) { 502 if (!w1.Member() || !w2.Member()) 503 return MinMaxWeightTpl<T>::NoWeight(); 504 return w1.Value() < w2.Value() ? w1 : w2; 505 } 506 507 inline MinMaxWeightTpl<float> Plus( 508 const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) { 509 return Plus<float>(w1, w2); 510 } 511 512 inline MinMaxWeightTpl<double> Plus( 513 const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) { 514 return Plus<double>(w1, w2); 515 } 516 517 // Max 518 template <class T> 519 inline MinMaxWeightTpl<T> Times( 520 const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) { 521 if (!w1.Member() || !w2.Member()) 522 return MinMaxWeightTpl<T>::NoWeight(); 523 return w1.Value() >= w2.Value() ? w1 : w2; 524 } 525 526 inline MinMaxWeightTpl<float> Times( 527 const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) { 528 return Times<float>(w1, w2); 529 } 530 531 inline MinMaxWeightTpl<double> Times( 532 const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) { 533 return Times<double>(w1, w2); 534 } 535 536 // Defined only for special cases 537 template <class T> 538 inline MinMaxWeightTpl<T> Divide(const MinMaxWeightTpl<T> &w1, 539 const MinMaxWeightTpl<T> &w2, 540 DivideType typ = DIVIDE_ANY) { 541 if (!w1.Member() || !w2.Member()) 542 return MinMaxWeightTpl<T>::NoWeight(); 543 // min(w1, x) = w2, w1 >= w2 => min(w1, x) = w2, x = w2 544 return w1.Value() >= w2.Value() ? w1 : FloatLimits<T>::kNumberBad; 545 } 546 547 inline MinMaxWeightTpl<float> Divide(const MinMaxWeightTpl<float> &w1, 548 const MinMaxWeightTpl<float> &w2, 549 DivideType typ = DIVIDE_ANY) { 550 return Divide<float>(w1, w2, typ); 551 } 552 553 inline MinMaxWeightTpl<double> Divide(const MinMaxWeightTpl<double> &w1, 554 const MinMaxWeightTpl<double> &w2, 555 DivideType typ = DIVIDE_ANY) { 556 return Divide<double>(w1, w2, typ); 557 } 558 559 // 560 // WEIGHT CONVERTER SPECIALIZATIONS. 561 // 562 563 // Convert to tropical 564 template <> 565 struct WeightConvert<LogWeight, TropicalWeight> { 566 TropicalWeight operator()(LogWeight w) const { return w.Value(); } 567 }; 568 569 template <> 570 struct WeightConvert<Log64Weight, TropicalWeight> { 571 TropicalWeight operator()(Log64Weight w) const { return w.Value(); } 572 }; 573 574 // Convert to log 575 template <> 576 struct WeightConvert<TropicalWeight, LogWeight> { 577 LogWeight operator()(TropicalWeight w) const { return w.Value(); } 578 }; 579 580 template <> 581 struct WeightConvert<Log64Weight, LogWeight> { 582 LogWeight operator()(Log64Weight w) const { return w.Value(); } 583 }; 584 585 // Convert to log64 586 template <> 587 struct WeightConvert<TropicalWeight, Log64Weight> { 588 Log64Weight operator()(TropicalWeight w) const { return w.Value(); } 589 }; 590 591 template <> 592 struct WeightConvert<LogWeight, Log64Weight> { 593 Log64Weight operator()(LogWeight w) const { return w.Value(); } 594 }; 595 596 } // namespace fst 597 598 #endif // FST_LIB_FLOAT_WEIGHT_H__ 599