1 // float-weight.h 2 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Copyright 2005-2010 Google, Inc. 16 // Author: riley (at) google.com (Michael Riley) 17 // 18 // \file 19 // Float weight set and associated semiring operation definitions. 20 // 21 22 #ifndef FST_LIB_FLOAT_WEIGHT_H__ 23 #define FST_LIB_FLOAT_WEIGHT_H__ 24 25 #include <limits> 26 #include <climits> 27 #include <sstream> 28 #include <string> 29 30 #include <fst/util.h> 31 #include <fst/weight.h> 32 33 34 namespace fst { 35 36 // numeric limits class 37 template <class T> 38 class FloatLimits { 39 public: 40 static const T PosInfinity() { 41 static const T pos_infinity = numeric_limits<T>::infinity(); 42 return pos_infinity; 43 } 44 45 static const T NegInfinity() { 46 static const T neg_infinity = -PosInfinity(); 47 return neg_infinity; 48 } 49 50 static const T NumberBad() { 51 static const T number_bad = numeric_limits<T>::quiet_NaN(); 52 return number_bad; 53 } 54 55 }; 56 57 // weight class to be templated on floating-points types 58 template <class T = float> 59 class FloatWeightTpl { 60 public: 61 FloatWeightTpl() {} 62 63 FloatWeightTpl(T f) : value_(f) {} 64 65 FloatWeightTpl(const FloatWeightTpl<T> &w) : value_(w.value_) {} 66 67 FloatWeightTpl<T> &operator=(const FloatWeightTpl<T> &w) { 68 value_ = w.value_; 69 return *this; 70 } 71 72 istream &Read(istream &strm) { 73 return ReadType(strm, &value_); 74 } 75 76 ostream &Write(ostream &strm) const { 77 return WriteType(strm, value_); 78 } 79 80 size_t Hash() const { 81 union { 82 T f; 83 size_t s; 84 } u; 85 u.s = 0; 86 u.f = value_; 87 return u.s; 88 } 89 90 const T &Value() const { return value_; } 91 92 protected: 93 void SetValue(const T &f) { value_ = f; } 94 95 inline static string GetPrecisionString() { 96 int64 size = sizeof(T); 97 if (size == sizeof(float)) return ""; 98 size *= CHAR_BIT; 99 100 string result; 101 Int64ToStr(size, &result); 102 return result; 103 } 104 105 private: 106 T value_; 107 }; 108 109 // Single-precision float weight 110 typedef FloatWeightTpl<float> FloatWeight; 111 112 template <class T> 113 inline bool operator==(const FloatWeightTpl<T> &w1, 114 const FloatWeightTpl<T> &w2) { 115 // Volatile qualifier thwarts over-aggressive compiler optimizations 116 // that lead to problems esp. with NaturalLess(). 117 volatile T v1 = w1.Value(); 118 volatile T v2 = w2.Value(); 119 return v1 == v2; 120 } 121 122 inline bool operator==(const FloatWeightTpl<double> &w1, 123 const FloatWeightTpl<double> &w2) { 124 return operator==<double>(w1, w2); 125 } 126 127 inline bool operator==(const FloatWeightTpl<float> &w1, 128 const FloatWeightTpl<float> &w2) { 129 return operator==<float>(w1, w2); 130 } 131 132 template <class T> 133 inline bool operator!=(const FloatWeightTpl<T> &w1, 134 const FloatWeightTpl<T> &w2) { 135 return !(w1 == w2); 136 } 137 138 inline bool operator!=(const FloatWeightTpl<double> &w1, 139 const FloatWeightTpl<double> &w2) { 140 return operator!=<double>(w1, w2); 141 } 142 143 inline bool operator!=(const FloatWeightTpl<float> &w1, 144 const FloatWeightTpl<float> &w2) { 145 return operator!=<float>(w1, w2); 146 } 147 148 template <class T> 149 inline bool ApproxEqual(const FloatWeightTpl<T> &w1, 150 const FloatWeightTpl<T> &w2, 151 float delta = kDelta) { 152 return w1.Value() <= w2.Value() + delta && w2.Value() <= w1.Value() + delta; 153 } 154 155 template <class T> 156 inline ostream &operator<<(ostream &strm, const FloatWeightTpl<T> &w) { 157 if (w.Value() == FloatLimits<T>::PosInfinity()) 158 return strm << "Infinity"; 159 else if (w.Value() == FloatLimits<T>::NegInfinity()) 160 return strm << "-Infinity"; 161 else if (w.Value() != w.Value()) // Fails for NaN 162 return strm << "BadNumber"; 163 else 164 return strm << w.Value(); 165 } 166 167 template <class T> 168 inline istream &operator>>(istream &strm, FloatWeightTpl<T> &w) { 169 string s; 170 strm >> s; 171 if (s == "Infinity") { 172 w = FloatWeightTpl<T>(FloatLimits<T>::PosInfinity()); 173 } else if (s == "-Infinity") { 174 w = FloatWeightTpl<T>(FloatLimits<T>::NegInfinity()); 175 } else { 176 char *p; 177 T f = strtod(s.c_str(), &p); 178 if (p < s.c_str() + s.size()) 179 strm.clear(std::ios::badbit); 180 else 181 w = FloatWeightTpl<T>(f); 182 } 183 return strm; 184 } 185 186 187 // Tropical semiring: (min, +, inf, 0) 188 template <class T> 189 class TropicalWeightTpl : public FloatWeightTpl<T> { 190 public: 191 using FloatWeightTpl<T>::Value; 192 193 typedef TropicalWeightTpl<T> ReverseWeight; 194 195 TropicalWeightTpl() : FloatWeightTpl<T>() {} 196 197 TropicalWeightTpl(T f) : FloatWeightTpl<T>(f) {} 198 199 TropicalWeightTpl(const TropicalWeightTpl<T> &w) : FloatWeightTpl<T>(w) {} 200 201 static const TropicalWeightTpl<T> Zero() { 202 return TropicalWeightTpl<T>(FloatLimits<T>::PosInfinity()); } 203 204 static const TropicalWeightTpl<T> One() { 205 return TropicalWeightTpl<T>(0.0F); } 206 207 static const TropicalWeightTpl<T> NoWeight() { 208 return TropicalWeightTpl<T>(FloatLimits<T>::NumberBad()); } 209 210 static const string &Type() { 211 static const string type = "tropical" + 212 FloatWeightTpl<T>::GetPrecisionString(); 213 return type; 214 } 215 216 bool Member() const { 217 // First part fails for IEEE NaN 218 return Value() == Value() && Value() != FloatLimits<T>::NegInfinity(); 219 } 220 221 TropicalWeightTpl<T> Quantize(float delta = kDelta) const { 222 if (Value() == FloatLimits<T>::NegInfinity() || 223 Value() == FloatLimits<T>::PosInfinity() || 224 Value() != Value()) 225 return *this; 226 else 227 return TropicalWeightTpl<T>(floor(Value()/delta + 0.5F) * delta); 228 } 229 230 TropicalWeightTpl<T> Reverse() const { return *this; } 231 232 static uint64 Properties() { 233 return kLeftSemiring | kRightSemiring | kCommutative | 234 kPath | kIdempotent; 235 } 236 }; 237 238 // Single precision tropical weight 239 typedef TropicalWeightTpl<float> TropicalWeight; 240 241 template <class T> 242 inline TropicalWeightTpl<T> Plus(const TropicalWeightTpl<T> &w1, 243 const TropicalWeightTpl<T> &w2) { 244 if (!w1.Member() || !w2.Member()) 245 return TropicalWeightTpl<T>::NoWeight(); 246 return w1.Value() < w2.Value() ? w1 : w2; 247 } 248 249 inline TropicalWeightTpl<float> Plus(const TropicalWeightTpl<float> &w1, 250 const TropicalWeightTpl<float> &w2) { 251 return Plus<float>(w1, w2); 252 } 253 254 inline TropicalWeightTpl<double> Plus(const TropicalWeightTpl<double> &w1, 255 const TropicalWeightTpl<double> &w2) { 256 return Plus<double>(w1, w2); 257 } 258 259 template <class T> 260 inline TropicalWeightTpl<T> Times(const TropicalWeightTpl<T> &w1, 261 const TropicalWeightTpl<T> &w2) { 262 if (!w1.Member() || !w2.Member()) 263 return TropicalWeightTpl<T>::NoWeight(); 264 T f1 = w1.Value(), f2 = w2.Value(); 265 if (f1 == FloatLimits<T>::PosInfinity()) 266 return w1; 267 else if (f2 == FloatLimits<T>::PosInfinity()) 268 return w2; 269 else 270 return TropicalWeightTpl<T>(f1 + f2); 271 } 272 273 inline TropicalWeightTpl<float> Times(const TropicalWeightTpl<float> &w1, 274 const TropicalWeightTpl<float> &w2) { 275 return Times<float>(w1, w2); 276 } 277 278 inline TropicalWeightTpl<double> Times(const TropicalWeightTpl<double> &w1, 279 const TropicalWeightTpl<double> &w2) { 280 return Times<double>(w1, w2); 281 } 282 283 template <class T> 284 inline TropicalWeightTpl<T> Divide(const TropicalWeightTpl<T> &w1, 285 const TropicalWeightTpl<T> &w2, 286 DivideType typ = DIVIDE_ANY) { 287 if (!w1.Member() || !w2.Member()) 288 return TropicalWeightTpl<T>::NoWeight(); 289 T f1 = w1.Value(), f2 = w2.Value(); 290 if (f2 == FloatLimits<T>::PosInfinity()) 291 return FloatLimits<T>::NumberBad(); 292 else if (f1 == FloatLimits<T>::PosInfinity()) 293 return FloatLimits<T>::PosInfinity(); 294 else 295 return TropicalWeightTpl<T>(f1 - f2); 296 } 297 298 inline TropicalWeightTpl<float> Divide(const TropicalWeightTpl<float> &w1, 299 const TropicalWeightTpl<float> &w2, 300 DivideType typ = DIVIDE_ANY) { 301 return Divide<float>(w1, w2, typ); 302 } 303 304 inline TropicalWeightTpl<double> Divide(const TropicalWeightTpl<double> &w1, 305 const TropicalWeightTpl<double> &w2, 306 DivideType typ = DIVIDE_ANY) { 307 return Divide<double>(w1, w2, typ); 308 } 309 310 311 // Log semiring: (log(e^-x + e^y), +, inf, 0) 312 template <class T> 313 class LogWeightTpl : public FloatWeightTpl<T> { 314 public: 315 using FloatWeightTpl<T>::Value; 316 317 typedef LogWeightTpl ReverseWeight; 318 319 LogWeightTpl() : FloatWeightTpl<T>() {} 320 321 LogWeightTpl(T f) : FloatWeightTpl<T>(f) {} 322 323 LogWeightTpl(const LogWeightTpl<T> &w) : FloatWeightTpl<T>(w) {} 324 325 static const LogWeightTpl<T> Zero() { 326 return LogWeightTpl<T>(FloatLimits<T>::PosInfinity()); 327 } 328 329 static const LogWeightTpl<T> One() { 330 return LogWeightTpl<T>(0.0F); 331 } 332 333 static const LogWeightTpl<T> NoWeight() { 334 return LogWeightTpl<T>(FloatLimits<T>::NumberBad()); } 335 336 static const string &Type() { 337 static const string type = "log" + FloatWeightTpl<T>::GetPrecisionString(); 338 return type; 339 } 340 341 bool Member() const { 342 // First part fails for IEEE NaN 343 return Value() == Value() && Value() != FloatLimits<T>::NegInfinity(); 344 } 345 346 LogWeightTpl<T> Quantize(float delta = kDelta) const { 347 if (Value() == FloatLimits<T>::NegInfinity() || 348 Value() == FloatLimits<T>::PosInfinity() || 349 Value() != Value()) 350 return *this; 351 else 352 return LogWeightTpl<T>(floor(Value()/delta + 0.5F) * delta); 353 } 354 355 LogWeightTpl<T> Reverse() const { return *this; } 356 357 static uint64 Properties() { 358 return kLeftSemiring | kRightSemiring | kCommutative; 359 } 360 }; 361 362 // Single-precision log weight 363 typedef LogWeightTpl<float> LogWeight; 364 // Double-precision log weight 365 typedef LogWeightTpl<double> Log64Weight; 366 367 template <class T> 368 inline T LogExp(T x) { return log(1.0F + exp(-x)); } 369 370 template <class T> 371 inline LogWeightTpl<T> Plus(const LogWeightTpl<T> &w1, 372 const LogWeightTpl<T> &w2) { 373 T f1 = w1.Value(), f2 = w2.Value(); 374 if (f1 == FloatLimits<T>::PosInfinity()) 375 return w2; 376 else if (f2 == FloatLimits<T>::PosInfinity()) 377 return w1; 378 else if (f1 > f2) 379 return LogWeightTpl<T>(f2 - LogExp(f1 - f2)); 380 else 381 return LogWeightTpl<T>(f1 - LogExp(f2 - f1)); 382 } 383 384 inline LogWeightTpl<float> Plus(const LogWeightTpl<float> &w1, 385 const LogWeightTpl<float> &w2) { 386 return Plus<float>(w1, w2); 387 } 388 389 inline LogWeightTpl<double> Plus(const LogWeightTpl<double> &w1, 390 const LogWeightTpl<double> &w2) { 391 return Plus<double>(w1, w2); 392 } 393 394 template <class T> 395 inline LogWeightTpl<T> Times(const LogWeightTpl<T> &w1, 396 const LogWeightTpl<T> &w2) { 397 if (!w1.Member() || !w2.Member()) 398 return LogWeightTpl<T>::NoWeight(); 399 T f1 = w1.Value(), f2 = w2.Value(); 400 if (f1 == FloatLimits<T>::PosInfinity()) 401 return w1; 402 else if (f2 == FloatLimits<T>::PosInfinity()) 403 return w2; 404 else 405 return LogWeightTpl<T>(f1 + f2); 406 } 407 408 inline LogWeightTpl<float> Times(const LogWeightTpl<float> &w1, 409 const LogWeightTpl<float> &w2) { 410 return Times<float>(w1, w2); 411 } 412 413 inline LogWeightTpl<double> Times(const LogWeightTpl<double> &w1, 414 const LogWeightTpl<double> &w2) { 415 return Times<double>(w1, w2); 416 } 417 418 template <class T> 419 inline LogWeightTpl<T> Divide(const LogWeightTpl<T> &w1, 420 const LogWeightTpl<T> &w2, 421 DivideType typ = DIVIDE_ANY) { 422 if (!w1.Member() || !w2.Member()) 423 return LogWeightTpl<T>::NoWeight(); 424 T f1 = w1.Value(), f2 = w2.Value(); 425 if (f2 == FloatLimits<T>::PosInfinity()) 426 return FloatLimits<T>::NumberBad(); 427 else if (f1 == FloatLimits<T>::PosInfinity()) 428 return FloatLimits<T>::PosInfinity(); 429 else 430 return LogWeightTpl<T>(f1 - f2); 431 } 432 433 inline LogWeightTpl<float> Divide(const LogWeightTpl<float> &w1, 434 const LogWeightTpl<float> &w2, 435 DivideType typ = DIVIDE_ANY) { 436 return Divide<float>(w1, w2, typ); 437 } 438 439 inline LogWeightTpl<double> Divide(const LogWeightTpl<double> &w1, 440 const LogWeightTpl<double> &w2, 441 DivideType typ = DIVIDE_ANY) { 442 return Divide<double>(w1, w2, typ); 443 } 444 445 // MinMax semiring: (min, max, inf, -inf) 446 template <class T> 447 class MinMaxWeightTpl : public FloatWeightTpl<T> { 448 public: 449 using FloatWeightTpl<T>::Value; 450 451 typedef MinMaxWeightTpl<T> ReverseWeight; 452 453 MinMaxWeightTpl() : FloatWeightTpl<T>() {} 454 455 MinMaxWeightTpl(T f) : FloatWeightTpl<T>(f) {} 456 457 MinMaxWeightTpl(const MinMaxWeightTpl<T> &w) : FloatWeightTpl<T>(w) {} 458 459 static const MinMaxWeightTpl<T> Zero() { 460 return MinMaxWeightTpl<T>(FloatLimits<T>::PosInfinity()); 461 } 462 463 static const MinMaxWeightTpl<T> One() { 464 return MinMaxWeightTpl<T>(FloatLimits<T>::NegInfinity()); 465 } 466 467 static const MinMaxWeightTpl<T> NoWeight() { 468 return MinMaxWeightTpl<T>(FloatLimits<T>::NumberBad()); } 469 470 static const string &Type() { 471 static const string type = "minmax" + 472 FloatWeightTpl<T>::GetPrecisionString(); 473 return type; 474 } 475 476 bool Member() const { 477 // Fails for IEEE NaN 478 return Value() == Value(); 479 } 480 481 MinMaxWeightTpl<T> Quantize(float delta = kDelta) const { 482 // If one of infinities, or a NaN 483 if (Value() == FloatLimits<T>::NegInfinity() || 484 Value() == FloatLimits<T>::PosInfinity() || 485 Value() != Value()) 486 return *this; 487 else 488 return MinMaxWeightTpl<T>(floor(Value()/delta + 0.5F) * delta); 489 } 490 491 MinMaxWeightTpl<T> Reverse() const { return *this; } 492 493 static uint64 Properties() { 494 return kLeftSemiring | kRightSemiring | kCommutative | kIdempotent | kPath; 495 } 496 }; 497 498 // Single-precision min-max weight 499 typedef MinMaxWeightTpl<float> MinMaxWeight; 500 501 // Min 502 template <class T> 503 inline MinMaxWeightTpl<T> Plus( 504 const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) { 505 if (!w1.Member() || !w2.Member()) 506 return MinMaxWeightTpl<T>::NoWeight(); 507 return w1.Value() < w2.Value() ? w1 : w2; 508 } 509 510 inline MinMaxWeightTpl<float> Plus( 511 const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) { 512 return Plus<float>(w1, w2); 513 } 514 515 inline MinMaxWeightTpl<double> Plus( 516 const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) { 517 return Plus<double>(w1, w2); 518 } 519 520 // Max 521 template <class T> 522 inline MinMaxWeightTpl<T> Times( 523 const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) { 524 if (!w1.Member() || !w2.Member()) 525 return MinMaxWeightTpl<T>::NoWeight(); 526 return w1.Value() >= w2.Value() ? w1 : w2; 527 } 528 529 inline MinMaxWeightTpl<float> Times( 530 const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) { 531 return Times<float>(w1, w2); 532 } 533 534 inline MinMaxWeightTpl<double> Times( 535 const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) { 536 return Times<double>(w1, w2); 537 } 538 539 // Defined only for special cases 540 template <class T> 541 inline MinMaxWeightTpl<T> Divide(const MinMaxWeightTpl<T> &w1, 542 const MinMaxWeightTpl<T> &w2, 543 DivideType typ = DIVIDE_ANY) { 544 if (!w1.Member() || !w2.Member()) 545 return MinMaxWeightTpl<T>::NoWeight(); 546 // min(w1, x) = w2, w1 >= w2 => min(w1, x) = w2, x = w2 547 return w1.Value() >= w2.Value() ? w1 : FloatLimits<T>::NumberBad(); 548 } 549 550 inline MinMaxWeightTpl<float> Divide(const MinMaxWeightTpl<float> &w1, 551 const MinMaxWeightTpl<float> &w2, 552 DivideType typ = DIVIDE_ANY) { 553 return Divide<float>(w1, w2, typ); 554 } 555 556 inline MinMaxWeightTpl<double> Divide(const MinMaxWeightTpl<double> &w1, 557 const MinMaxWeightTpl<double> &w2, 558 DivideType typ = DIVIDE_ANY) { 559 return Divide<double>(w1, w2, typ); 560 } 561 562 // 563 // WEIGHT CONVERTER SPECIALIZATIONS. 564 // 565 566 // Convert to tropical 567 template <> 568 struct WeightConvert<LogWeight, TropicalWeight> { 569 TropicalWeight operator()(LogWeight w) const { return w.Value(); } 570 }; 571 572 template <> 573 struct WeightConvert<Log64Weight, TropicalWeight> { 574 TropicalWeight operator()(Log64Weight w) const { return w.Value(); } 575 }; 576 577 // Convert to log 578 template <> 579 struct WeightConvert<TropicalWeight, LogWeight> { 580 LogWeight operator()(TropicalWeight w) const { return w.Value(); } 581 }; 582 583 template <> 584 struct WeightConvert<Log64Weight, LogWeight> { 585 LogWeight operator()(Log64Weight w) const { return w.Value(); } 586 }; 587 588 // Convert to log64 589 template <> 590 struct WeightConvert<TropicalWeight, Log64Weight> { 591 Log64Weight operator()(TropicalWeight w) const { return w.Value(); } 592 }; 593 594 template <> 595 struct WeightConvert<LogWeight, Log64Weight> { 596 Log64Weight operator()(LogWeight w) const { return w.Value(); } 597 }; 598 599 } // namespace fst 600 601 #endif // FST_LIB_FLOAT_WEIGHT_H__ 602