Home | History | Annotate | Download | only in fst
      1 // float-weight.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // \file
     19 // Float weight set and associated semiring operation definitions.
     20 //
     21 
     22 #ifndef FST_LIB_FLOAT_WEIGHT_H__
     23 #define FST_LIB_FLOAT_WEIGHT_H__
     24 
#include <climits>
#include <cstring>
#include <limits>
#include <sstream>
#include <string>
     29 
     30 #include <fst/util.h>
     31 #include <fst/weight.h>
     32 
     33 
     34 namespace fst {
     35 
     36 // numeric limits class
     37 template <class T>
     38 class FloatLimits {
     39  public:
     40   static const T PosInfinity() {
     41     static const T pos_infinity = numeric_limits<T>::infinity();
     42     return pos_infinity;
     43   }
     44 
     45   static const T NegInfinity() {
     46     static const T neg_infinity = -PosInfinity();
     47     return neg_infinity;
     48   }
     49 
     50   static const T NumberBad() {
     51     static const T number_bad = numeric_limits<T>::quiet_NaN();
     52     return number_bad;
     53   }
     54 
     55 };
     56 
     57 // weight class to be templated on floating-points types
     58 template <class T = float>
     59 class FloatWeightTpl {
     60  public:
     61   FloatWeightTpl() {}
     62 
     63   FloatWeightTpl(T f) : value_(f) {}
     64 
     65   FloatWeightTpl(const FloatWeightTpl<T> &w) : value_(w.value_) {}
     66 
     67   FloatWeightTpl<T> &operator=(const FloatWeightTpl<T> &w) {
     68     value_ = w.value_;
     69     return *this;
     70   }
     71 
     72   istream &Read(istream &strm) {
     73     return ReadType(strm, &value_);
     74   }
     75 
     76   ostream &Write(ostream &strm) const {
     77     return WriteType(strm, value_);
     78   }
     79 
     80   size_t Hash() const {
     81     union {
     82       T f;
     83       size_t s;
     84     } u;
     85     u.s = 0;
     86     u.f = value_;
     87     return u.s;
     88   }
     89 
     90   const T &Value() const { return value_; }
     91 
     92  protected:
     93   void SetValue(const T &f) { value_ = f; }
     94 
     95   inline static string GetPrecisionString() {
     96     int64 size = sizeof(T);
     97     if (size == sizeof(float)) return "";
     98     size *= CHAR_BIT;
     99 
    100     string result;
    101     Int64ToStr(size, &result);
    102     return result;
    103   }
    104 
    105  private:
    106   T value_;
    107 };
    108 
    109 // Single-precision float weight
    110 typedef FloatWeightTpl<float> FloatWeight;
    111 
    112 template <class T>
    113 inline bool operator==(const FloatWeightTpl<T> &w1,
    114                        const FloatWeightTpl<T> &w2) {
    115   // Volatile qualifier thwarts over-aggressive compiler optimizations
    116   // that lead to problems esp. with NaturalLess().
    117   volatile T v1 = w1.Value();
    118   volatile T v2 = w2.Value();
    119   return v1 == v2;
    120 }
    121 
    122 inline bool operator==(const FloatWeightTpl<double> &w1,
    123                        const FloatWeightTpl<double> &w2) {
    124   return operator==<double>(w1, w2);
    125 }
    126 
    127 inline bool operator==(const FloatWeightTpl<float> &w1,
    128                        const FloatWeightTpl<float> &w2) {
    129   return operator==<float>(w1, w2);
    130 }
    131 
    132 template <class T>
    133 inline bool operator!=(const FloatWeightTpl<T> &w1,
    134                        const FloatWeightTpl<T> &w2) {
    135   return !(w1 == w2);
    136 }
    137 
    138 inline bool operator!=(const FloatWeightTpl<double> &w1,
    139                        const FloatWeightTpl<double> &w2) {
    140   return operator!=<double>(w1, w2);
    141 }
    142 
    143 inline bool operator!=(const FloatWeightTpl<float> &w1,
    144                        const FloatWeightTpl<float> &w2) {
    145   return operator!=<float>(w1, w2);
    146 }
    147 
    148 template <class T>
    149 inline bool ApproxEqual(const FloatWeightTpl<T> &w1,
    150                         const FloatWeightTpl<T> &w2,
    151                         float delta = kDelta) {
    152   return w1.Value() <= w2.Value() + delta && w2.Value() <= w1.Value() + delta;
    153 }
    154 
    155 template <class T>
    156 inline ostream &operator<<(ostream &strm, const FloatWeightTpl<T> &w) {
    157   if (w.Value() == FloatLimits<T>::PosInfinity())
    158     return strm << "Infinity";
    159   else if (w.Value() == FloatLimits<T>::NegInfinity())
    160     return strm << "-Infinity";
    161   else if (w.Value() != w.Value())   // Fails for NaN
    162     return strm << "BadNumber";
    163   else
    164     return strm << w.Value();
    165 }
    166 
    167 template <class T>
    168 inline istream &operator>>(istream &strm, FloatWeightTpl<T> &w) {
    169   string s;
    170   strm >> s;
    171   if (s == "Infinity") {
    172     w = FloatWeightTpl<T>(FloatLimits<T>::PosInfinity());
    173   } else if (s == "-Infinity") {
    174     w = FloatWeightTpl<T>(FloatLimits<T>::NegInfinity());
    175   } else {
    176     char *p;
    177     T f = strtod(s.c_str(), &p);
    178     if (p < s.c_str() + s.size())
    179       strm.clear(std::ios::badbit);
    180     else
    181       w = FloatWeightTpl<T>(f);
    182   }
    183   return strm;
    184 }
    185 
    186 
    187 // Tropical semiring: (min, +, inf, 0)
    188 template <class T>
    189 class TropicalWeightTpl : public FloatWeightTpl<T> {
    190  public:
    191   using FloatWeightTpl<T>::Value;
    192 
    193   typedef TropicalWeightTpl<T> ReverseWeight;
    194 
    195   TropicalWeightTpl() : FloatWeightTpl<T>() {}
    196 
    197   TropicalWeightTpl(T f) : FloatWeightTpl<T>(f) {}
    198 
    199   TropicalWeightTpl(const TropicalWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}
    200 
    201   static const TropicalWeightTpl<T> Zero() {
    202     return TropicalWeightTpl<T>(FloatLimits<T>::PosInfinity()); }
    203 
    204   static const TropicalWeightTpl<T> One() {
    205     return TropicalWeightTpl<T>(0.0F); }
    206 
    207   static const TropicalWeightTpl<T> NoWeight() {
    208     return TropicalWeightTpl<T>(FloatLimits<T>::NumberBad()); }
    209 
    210   static const string &Type() {
    211     static const string type = "tropical" +
    212         FloatWeightTpl<T>::GetPrecisionString();
    213     return type;
    214   }
    215 
    216   bool Member() const {
    217     // First part fails for IEEE NaN
    218     return Value() == Value() && Value() != FloatLimits<T>::NegInfinity();
    219   }
    220 
    221   TropicalWeightTpl<T> Quantize(float delta = kDelta) const {
    222     if (Value() == FloatLimits<T>::NegInfinity() ||
    223         Value() == FloatLimits<T>::PosInfinity() ||
    224         Value() != Value())
    225       return *this;
    226     else
    227       return TropicalWeightTpl<T>(floor(Value()/delta + 0.5F) * delta);
    228   }
    229 
    230   TropicalWeightTpl<T> Reverse() const { return *this; }
    231 
    232   static uint64 Properties() {
    233     return kLeftSemiring | kRightSemiring | kCommutative |
    234         kPath | kIdempotent;
    235   }
    236 };
    237 
    238 // Single precision tropical weight
    239 typedef TropicalWeightTpl<float> TropicalWeight;
    240 
    241 template <class T>
    242 inline TropicalWeightTpl<T> Plus(const TropicalWeightTpl<T> &w1,
    243                                  const TropicalWeightTpl<T> &w2) {
    244   if (!w1.Member() || !w2.Member())
    245     return TropicalWeightTpl<T>::NoWeight();
    246   return w1.Value() < w2.Value() ? w1 : w2;
    247 }
    248 
    249 inline TropicalWeightTpl<float> Plus(const TropicalWeightTpl<float> &w1,
    250                                      const TropicalWeightTpl<float> &w2) {
    251   return Plus<float>(w1, w2);
    252 }
    253 
    254 inline TropicalWeightTpl<double> Plus(const TropicalWeightTpl<double> &w1,
    255                                       const TropicalWeightTpl<double> &w2) {
    256   return Plus<double>(w1, w2);
    257 }
    258 
    259 template <class T>
    260 inline TropicalWeightTpl<T> Times(const TropicalWeightTpl<T> &w1,
    261                                   const TropicalWeightTpl<T> &w2) {
    262   if (!w1.Member() || !w2.Member())
    263     return TropicalWeightTpl<T>::NoWeight();
    264   T f1 = w1.Value(), f2 = w2.Value();
    265   if (f1 == FloatLimits<T>::PosInfinity())
    266     return w1;
    267   else if (f2 == FloatLimits<T>::PosInfinity())
    268     return w2;
    269   else
    270     return TropicalWeightTpl<T>(f1 + f2);
    271 }
    272 
    273 inline TropicalWeightTpl<float> Times(const TropicalWeightTpl<float> &w1,
    274                                       const TropicalWeightTpl<float> &w2) {
    275   return Times<float>(w1, w2);
    276 }
    277 
    278 inline TropicalWeightTpl<double> Times(const TropicalWeightTpl<double> &w1,
    279                                        const TropicalWeightTpl<double> &w2) {
    280   return Times<double>(w1, w2);
    281 }
    282 
    283 template <class T>
    284 inline TropicalWeightTpl<T> Divide(const TropicalWeightTpl<T> &w1,
    285                                    const TropicalWeightTpl<T> &w2,
    286                                    DivideType typ = DIVIDE_ANY) {
    287   if (!w1.Member() || !w2.Member())
    288     return TropicalWeightTpl<T>::NoWeight();
    289   T f1 = w1.Value(), f2 = w2.Value();
    290   if (f2 == FloatLimits<T>::PosInfinity())
    291     return FloatLimits<T>::NumberBad();
    292   else if (f1 == FloatLimits<T>::PosInfinity())
    293     return FloatLimits<T>::PosInfinity();
    294   else
    295     return TropicalWeightTpl<T>(f1 - f2);
    296 }
    297 
    298 inline TropicalWeightTpl<float> Divide(const TropicalWeightTpl<float> &w1,
    299                                        const TropicalWeightTpl<float> &w2,
    300                                        DivideType typ = DIVIDE_ANY) {
    301   return Divide<float>(w1, w2, typ);
    302 }
    303 
    304 inline TropicalWeightTpl<double> Divide(const TropicalWeightTpl<double> &w1,
    305                                         const TropicalWeightTpl<double> &w2,
    306                                         DivideType typ = DIVIDE_ANY) {
    307   return Divide<double>(w1, w2, typ);
    308 }
    309 
    310 
    311 // Log semiring: (log(e^-x + e^y), +, inf, 0)
    312 template <class T>
    313 class LogWeightTpl : public FloatWeightTpl<T> {
    314  public:
    315   using FloatWeightTpl<T>::Value;
    316 
    317   typedef LogWeightTpl ReverseWeight;
    318 
    319   LogWeightTpl() : FloatWeightTpl<T>() {}
    320 
    321   LogWeightTpl(T f) : FloatWeightTpl<T>(f) {}
    322 
    323   LogWeightTpl(const LogWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}
    324 
    325   static const LogWeightTpl<T> Zero() {
    326     return LogWeightTpl<T>(FloatLimits<T>::PosInfinity());
    327   }
    328 
    329   static const LogWeightTpl<T> One() {
    330     return LogWeightTpl<T>(0.0F);
    331   }
    332 
    333   static const LogWeightTpl<T> NoWeight() {
    334     return LogWeightTpl<T>(FloatLimits<T>::NumberBad()); }
    335 
    336   static const string &Type() {
    337     static const string type = "log" + FloatWeightTpl<T>::GetPrecisionString();
    338     return type;
    339   }
    340 
    341   bool Member() const {
    342     // First part fails for IEEE NaN
    343     return Value() == Value() && Value() != FloatLimits<T>::NegInfinity();
    344   }
    345 
    346   LogWeightTpl<T> Quantize(float delta = kDelta) const {
    347     if (Value() == FloatLimits<T>::NegInfinity() ||
    348         Value() == FloatLimits<T>::PosInfinity() ||
    349         Value() != Value())
    350       return *this;
    351     else
    352       return LogWeightTpl<T>(floor(Value()/delta + 0.5F) * delta);
    353   }
    354 
    355   LogWeightTpl<T> Reverse() const { return *this; }
    356 
    357   static uint64 Properties() {
    358     return kLeftSemiring | kRightSemiring | kCommutative;
    359   }
    360 };
    361 
    362 // Single-precision log weight
    363 typedef LogWeightTpl<float> LogWeight;
    364 // Double-precision log weight
    365 typedef LogWeightTpl<double> Log64Weight;
    366 
    367 template <class T>
    368 inline T LogExp(T x) { return log(1.0F + exp(-x)); }
    369 
    370 template <class T>
    371 inline LogWeightTpl<T> Plus(const LogWeightTpl<T> &w1,
    372                             const LogWeightTpl<T> &w2) {
    373   T f1 = w1.Value(), f2 = w2.Value();
    374   if (f1 == FloatLimits<T>::PosInfinity())
    375     return w2;
    376   else if (f2 == FloatLimits<T>::PosInfinity())
    377     return w1;
    378   else if (f1 > f2)
    379     return LogWeightTpl<T>(f2 - LogExp(f1 - f2));
    380   else
    381     return LogWeightTpl<T>(f1 - LogExp(f2 - f1));
    382 }
    383 
    384 inline LogWeightTpl<float> Plus(const LogWeightTpl<float> &w1,
    385                                 const LogWeightTpl<float> &w2) {
    386   return Plus<float>(w1, w2);
    387 }
    388 
    389 inline LogWeightTpl<double> Plus(const LogWeightTpl<double> &w1,
    390                                  const LogWeightTpl<double> &w2) {
    391   return Plus<double>(w1, w2);
    392 }
    393 
    394 template <class T>
    395 inline LogWeightTpl<T> Times(const LogWeightTpl<T> &w1,
    396                              const LogWeightTpl<T> &w2) {
    397   if (!w1.Member() || !w2.Member())
    398     return LogWeightTpl<T>::NoWeight();
    399   T f1 = w1.Value(), f2 = w2.Value();
    400   if (f1 == FloatLimits<T>::PosInfinity())
    401     return w1;
    402   else if (f2 == FloatLimits<T>::PosInfinity())
    403     return w2;
    404   else
    405     return LogWeightTpl<T>(f1 + f2);
    406 }
    407 
    408 inline LogWeightTpl<float> Times(const LogWeightTpl<float> &w1,
    409                                  const LogWeightTpl<float> &w2) {
    410   return Times<float>(w1, w2);
    411 }
    412 
    413 inline LogWeightTpl<double> Times(const LogWeightTpl<double> &w1,
    414                                   const LogWeightTpl<double> &w2) {
    415   return Times<double>(w1, w2);
    416 }
    417 
    418 template <class T>
    419 inline LogWeightTpl<T> Divide(const LogWeightTpl<T> &w1,
    420                               const LogWeightTpl<T> &w2,
    421                               DivideType typ = DIVIDE_ANY) {
    422   if (!w1.Member() || !w2.Member())
    423     return LogWeightTpl<T>::NoWeight();
    424   T f1 = w1.Value(), f2 = w2.Value();
    425   if (f2 == FloatLimits<T>::PosInfinity())
    426     return FloatLimits<T>::NumberBad();
    427   else if (f1 == FloatLimits<T>::PosInfinity())
    428     return FloatLimits<T>::PosInfinity();
    429   else
    430     return LogWeightTpl<T>(f1 - f2);
    431 }
    432 
    433 inline LogWeightTpl<float> Divide(const LogWeightTpl<float> &w1,
    434                                   const LogWeightTpl<float> &w2,
    435                                   DivideType typ = DIVIDE_ANY) {
    436   return Divide<float>(w1, w2, typ);
    437 }
    438 
    439 inline LogWeightTpl<double> Divide(const LogWeightTpl<double> &w1,
    440                                    const LogWeightTpl<double> &w2,
    441                                    DivideType typ = DIVIDE_ANY) {
    442   return Divide<double>(w1, w2, typ);
    443 }
    444 
    445 // MinMax semiring: (min, max, inf, -inf)
    446 template <class T>
    447 class MinMaxWeightTpl : public FloatWeightTpl<T> {
    448  public:
    449   using FloatWeightTpl<T>::Value;
    450 
    451   typedef MinMaxWeightTpl<T> ReverseWeight;
    452 
    453   MinMaxWeightTpl() : FloatWeightTpl<T>() {}
    454 
    455   MinMaxWeightTpl(T f) : FloatWeightTpl<T>(f) {}
    456 
    457   MinMaxWeightTpl(const MinMaxWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}
    458 
    459   static const MinMaxWeightTpl<T> Zero() {
    460     return MinMaxWeightTpl<T>(FloatLimits<T>::PosInfinity());
    461   }
    462 
    463   static const MinMaxWeightTpl<T> One() {
    464     return MinMaxWeightTpl<T>(FloatLimits<T>::NegInfinity());
    465   }
    466 
    467   static const MinMaxWeightTpl<T> NoWeight() {
    468     return MinMaxWeightTpl<T>(FloatLimits<T>::NumberBad()); }
    469 
    470   static const string &Type() {
    471     static const string type = "minmax" +
    472         FloatWeightTpl<T>::GetPrecisionString();
    473     return type;
    474   }
    475 
    476   bool Member() const {
    477     // Fails for IEEE NaN
    478     return Value() == Value();
    479   }
    480 
    481   MinMaxWeightTpl<T> Quantize(float delta = kDelta) const {
    482     // If one of infinities, or a NaN
    483     if (Value() == FloatLimits<T>::NegInfinity() ||
    484         Value() == FloatLimits<T>::PosInfinity() ||
    485         Value() != Value())
    486       return *this;
    487     else
    488       return MinMaxWeightTpl<T>(floor(Value()/delta + 0.5F) * delta);
    489   }
    490 
    491   MinMaxWeightTpl<T> Reverse() const { return *this; }
    492 
    493   static uint64 Properties() {
    494     return kLeftSemiring | kRightSemiring | kCommutative | kIdempotent | kPath;
    495   }
    496 };
    497 
    498 // Single-precision min-max weight
    499 typedef MinMaxWeightTpl<float> MinMaxWeight;
    500 
    501 // Min
    502 template <class T>
    503 inline MinMaxWeightTpl<T> Plus(
    504     const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) {
    505   if (!w1.Member() || !w2.Member())
    506     return MinMaxWeightTpl<T>::NoWeight();
    507   return w1.Value() < w2.Value() ? w1 : w2;
    508 }
    509 
    510 inline MinMaxWeightTpl<float> Plus(
    511     const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) {
    512   return Plus<float>(w1, w2);
    513 }
    514 
    515 inline MinMaxWeightTpl<double> Plus(
    516     const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) {
    517   return Plus<double>(w1, w2);
    518 }
    519 
    520 // Max
    521 template <class T>
    522 inline MinMaxWeightTpl<T> Times(
    523     const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) {
    524   if (!w1.Member() || !w2.Member())
    525     return MinMaxWeightTpl<T>::NoWeight();
    526   return w1.Value() >= w2.Value() ? w1 : w2;
    527 }
    528 
    529 inline MinMaxWeightTpl<float> Times(
    530     const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) {
    531   return Times<float>(w1, w2);
    532 }
    533 
    534 inline MinMaxWeightTpl<double> Times(
    535     const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) {
    536   return Times<double>(w1, w2);
    537 }
    538 
    539 // Defined only for special cases
    540 template <class T>
    541 inline MinMaxWeightTpl<T> Divide(const MinMaxWeightTpl<T> &w1,
    542                                  const MinMaxWeightTpl<T> &w2,
    543                                  DivideType typ = DIVIDE_ANY) {
    544   if (!w1.Member() || !w2.Member())
    545     return MinMaxWeightTpl<T>::NoWeight();
    546   // min(w1, x) = w2, w1 >= w2 => min(w1, x) = w2, x = w2
    547   return w1.Value() >= w2.Value() ? w1 : FloatLimits<T>::NumberBad();
    548 }
    549 
    550 inline MinMaxWeightTpl<float> Divide(const MinMaxWeightTpl<float> &w1,
    551                                      const MinMaxWeightTpl<float> &w2,
    552                                      DivideType typ = DIVIDE_ANY) {
    553   return Divide<float>(w1, w2, typ);
    554 }
    555 
    556 inline MinMaxWeightTpl<double> Divide(const MinMaxWeightTpl<double> &w1,
    557                                       const MinMaxWeightTpl<double> &w2,
    558                                       DivideType typ = DIVIDE_ANY) {
    559   return Divide<double>(w1, w2, typ);
    560 }
    561 
    562 //
    563 // WEIGHT CONVERTER SPECIALIZATIONS.
    564 //
    565 
    566 // Convert to tropical
    567 template <>
    568 struct WeightConvert<LogWeight, TropicalWeight> {
    569   TropicalWeight operator()(LogWeight w) const { return w.Value(); }
    570 };
    571 
    572 template <>
    573 struct WeightConvert<Log64Weight, TropicalWeight> {
    574   TropicalWeight operator()(Log64Weight w) const { return w.Value(); }
    575 };
    576 
    577 // Convert to log
    578 template <>
    579 struct WeightConvert<TropicalWeight, LogWeight> {
    580   LogWeight operator()(TropicalWeight w) const { return w.Value(); }
    581 };
    582 
    583 template <>
    584 struct WeightConvert<Log64Weight, LogWeight> {
    585   LogWeight operator()(Log64Weight w) const { return w.Value(); }
    586 };
    587 
    588 // Convert to log64
    589 template <>
    590 struct WeightConvert<TropicalWeight, Log64Weight> {
    591   Log64Weight operator()(TropicalWeight w) const { return w.Value(); }
    592 };
    593 
    594 template <>
    595 struct WeightConvert<LogWeight, Log64Weight> {
    596   Log64Weight operator()(LogWeight w) const { return w.Value(); }
    597 };
    598 
    599 }  // namespace fst
    600 
    601 #endif  // FST_LIB_FLOAT_WEIGHT_H__
    602