Home | History | Annotate | Download | only in fst
      1 // float-weight.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // \file
     19 // Float weight set and associated semiring operation definitions.
     20 //
     21 
     22 #ifndef FST_LIB_FLOAT_WEIGHT_H__
     23 #define FST_LIB_FLOAT_WEIGHT_H__
     24 
#include <climits>
#include <cmath>
#include <cstring>
#include <limits>
#include <sstream>
#include <string>

#include <fst/util.h>
#include <fst/weight.h>
     32 
     33 
     34 namespace fst {
     35 
     36 // numeric limits class
     37 template <class T>
     38 class FloatLimits {
     39  public:
     40   static const T kPosInfinity;
     41   static const T kNegInfinity;
     42   static const T kNumberBad;
     43 };
     44 
     45 template <class T>
     46 const T FloatLimits<T>::kPosInfinity = numeric_limits<T>::infinity();
     47 
     48 template <class T>
     49 const T FloatLimits<T>::kNegInfinity = -FloatLimits<T>::kPosInfinity;
     50 
     51 template <class T>
     52 const T FloatLimits<T>::kNumberBad = numeric_limits<T>::quiet_NaN();
     53 
     54 // weight class to be templated on floating-points types
     55 template <class T = float>
     56 class FloatWeightTpl {
     57  public:
     58   FloatWeightTpl() {}
     59 
     60   FloatWeightTpl(T f) : value_(f) {}
     61 
     62   FloatWeightTpl(const FloatWeightTpl<T> &w) : value_(w.value_) {}
     63 
     64   FloatWeightTpl<T> &operator=(const FloatWeightTpl<T> &w) {
     65     value_ = w.value_;
     66     return *this;
     67   }
     68 
     69   istream &Read(istream &strm) {
     70     return ReadType(strm, &value_);
     71   }
     72 
     73   ostream &Write(ostream &strm) const {
     74     return WriteType(strm, value_);
     75   }
     76 
     77   size_t Hash() const {
     78     union {
     79       T f;
     80       size_t s;
     81     } u;
     82     u.s = 0;
     83     u.f = value_;
     84     return u.s;
     85   }
     86 
     87   const T &Value() const { return value_; }
     88 
     89  protected:
     90   void SetValue(const T &f) { value_ = f; }
     91 
     92   inline static string GetPrecisionString() {
     93     int64 size = sizeof(T);
     94     if (size == sizeof(float)) return "";
     95     size *= CHAR_BIT;
     96 
     97     string result;
     98     Int64ToStr(size, &result);
     99     return result;
    100   }
    101 
    102  private:
    103   T value_;
    104 };
    105 
    106 // Single-precision float weight
    107 typedef FloatWeightTpl<float> FloatWeight;
    108 
    109 template <class T>
    110 inline bool operator==(const FloatWeightTpl<T> &w1,
    111                        const FloatWeightTpl<T> &w2) {
    112   // Volatile qualifier thwarts over-aggressive compiler optimizations
    113   // that lead to problems esp. with NaturalLess().
    114   volatile T v1 = w1.Value();
    115   volatile T v2 = w2.Value();
    116   return v1 == v2;
    117 }
    118 
    119 inline bool operator==(const FloatWeightTpl<double> &w1,
    120                        const FloatWeightTpl<double> &w2) {
    121   return operator==<double>(w1, w2);
    122 }
    123 
    124 inline bool operator==(const FloatWeightTpl<float> &w1,
    125                        const FloatWeightTpl<float> &w2) {
    126   return operator==<float>(w1, w2);
    127 }
    128 
    129 template <class T>
    130 inline bool operator!=(const FloatWeightTpl<T> &w1,
    131                        const FloatWeightTpl<T> &w2) {
    132   return !(w1 == w2);
    133 }
    134 
    135 inline bool operator!=(const FloatWeightTpl<double> &w1,
    136                        const FloatWeightTpl<double> &w2) {
    137   return operator!=<double>(w1, w2);
    138 }
    139 
    140 inline bool operator!=(const FloatWeightTpl<float> &w1,
    141                        const FloatWeightTpl<float> &w2) {
    142   return operator!=<float>(w1, w2);
    143 }
    144 
    145 template <class T>
    146 inline bool ApproxEqual(const FloatWeightTpl<T> &w1,
    147                         const FloatWeightTpl<T> &w2,
    148                         float delta = kDelta) {
    149   return w1.Value() <= w2.Value() + delta && w2.Value() <= w1.Value() + delta;
    150 }
    151 
    152 template <class T>
    153 inline ostream &operator<<(ostream &strm, const FloatWeightTpl<T> &w) {
    154   if (w.Value() == FloatLimits<T>::kPosInfinity)
    155     return strm << "Infinity";
    156   else if (w.Value() == FloatLimits<T>::kNegInfinity)
    157     return strm << "-Infinity";
    158   else if (w.Value() != w.Value())   // Fails for NaN
    159     return strm << "BadNumber";
    160   else
    161     return strm << w.Value();
    162 }
    163 
    164 template <class T>
    165 inline istream &operator>>(istream &strm, FloatWeightTpl<T> &w) {
    166   string s;
    167   strm >> s;
    168   if (s == "Infinity") {
    169     w = FloatWeightTpl<T>(FloatLimits<T>::kPosInfinity);
    170   } else if (s == "-Infinity") {
    171     w = FloatWeightTpl<T>(FloatLimits<T>::kNegInfinity);
    172   } else {
    173     char *p;
    174     T f = strtod(s.c_str(), &p);
    175     if (p < s.c_str() + s.size())
    176       strm.clear(std::ios::badbit);
    177     else
    178       w = FloatWeightTpl<T>(f);
    179   }
    180   return strm;
    181 }
    182 
    183 
    184 // Tropical semiring: (min, +, inf, 0)
    185 template <class T>
    186 class TropicalWeightTpl : public FloatWeightTpl<T> {
    187  public:
    188   using FloatWeightTpl<T>::Value;
    189 
    190   typedef TropicalWeightTpl<T> ReverseWeight;
    191 
    192   TropicalWeightTpl() : FloatWeightTpl<T>() {}
    193 
    194   TropicalWeightTpl(T f) : FloatWeightTpl<T>(f) {}
    195 
    196   TropicalWeightTpl(const TropicalWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}
    197 
    198   static const TropicalWeightTpl<T> Zero() {
    199     return TropicalWeightTpl<T>(FloatLimits<T>::kPosInfinity); }
    200 
    201   static const TropicalWeightTpl<T> One() {
    202     return TropicalWeightTpl<T>(0.0F); }
    203 
    204   static const TropicalWeightTpl<T> NoWeight() {
    205     return TropicalWeightTpl<T>(FloatLimits<T>::kNumberBad); }
    206 
    207   static const string &Type() {
    208     static const string type = "tropical" +
    209         FloatWeightTpl<T>::GetPrecisionString();
    210     return type;
    211   }
    212 
    213   bool Member() const {
    214     // First part fails for IEEE NaN
    215     return Value() == Value() && Value() != FloatLimits<T>::kNegInfinity;
    216   }
    217 
    218   TropicalWeightTpl<T> Quantize(float delta = kDelta) const {
    219     if (Value() == FloatLimits<T>::kNegInfinity ||
    220         Value() == FloatLimits<T>::kPosInfinity ||
    221         Value() != Value())
    222       return *this;
    223     else
    224       return TropicalWeightTpl<T>(floor(Value()/delta + 0.5F) * delta);
    225   }
    226 
    227   TropicalWeightTpl<T> Reverse() const { return *this; }
    228 
    229   static uint64 Properties() {
    230     return kLeftSemiring | kRightSemiring | kCommutative |
    231         kPath | kIdempotent;
    232   }
    233 };
    234 
    235 // Single precision tropical weight
    236 typedef TropicalWeightTpl<float> TropicalWeight;
    237 
    238 template <class T>
    239 inline TropicalWeightTpl<T> Plus(const TropicalWeightTpl<T> &w1,
    240                                  const TropicalWeightTpl<T> &w2) {
    241   if (!w1.Member() || !w2.Member())
    242     return TropicalWeightTpl<T>::NoWeight();
    243   return w1.Value() < w2.Value() ? w1 : w2;
    244 }
    245 
    246 inline TropicalWeightTpl<float> Plus(const TropicalWeightTpl<float> &w1,
    247                                      const TropicalWeightTpl<float> &w2) {
    248   return Plus<float>(w1, w2);
    249 }
    250 
    251 inline TropicalWeightTpl<double> Plus(const TropicalWeightTpl<double> &w1,
    252                                       const TropicalWeightTpl<double> &w2) {
    253   return Plus<double>(w1, w2);
    254 }
    255 
    256 template <class T>
    257 inline TropicalWeightTpl<T> Times(const TropicalWeightTpl<T> &w1,
    258                                   const TropicalWeightTpl<T> &w2) {
    259   if (!w1.Member() || !w2.Member())
    260     return TropicalWeightTpl<T>::NoWeight();
    261   T f1 = w1.Value(), f2 = w2.Value();
    262   if (f1 == FloatLimits<T>::kPosInfinity)
    263     return w1;
    264   else if (f2 == FloatLimits<T>::kPosInfinity)
    265     return w2;
    266   else
    267     return TropicalWeightTpl<T>(f1 + f2);
    268 }
    269 
    270 inline TropicalWeightTpl<float> Times(const TropicalWeightTpl<float> &w1,
    271                                       const TropicalWeightTpl<float> &w2) {
    272   return Times<float>(w1, w2);
    273 }
    274 
    275 inline TropicalWeightTpl<double> Times(const TropicalWeightTpl<double> &w1,
    276                                        const TropicalWeightTpl<double> &w2) {
    277   return Times<double>(w1, w2);
    278 }
    279 
    280 template <class T>
    281 inline TropicalWeightTpl<T> Divide(const TropicalWeightTpl<T> &w1,
    282                                    const TropicalWeightTpl<T> &w2,
    283                                    DivideType typ = DIVIDE_ANY) {
    284   if (!w1.Member() || !w2.Member())
    285     return TropicalWeightTpl<T>::NoWeight();
    286   T f1 = w1.Value(), f2 = w2.Value();
    287   if (f2 == FloatLimits<T>::kPosInfinity)
    288     return FloatLimits<T>::kNumberBad;
    289   else if (f1 == FloatLimits<T>::kPosInfinity)
    290     return FloatLimits<T>::kPosInfinity;
    291   else
    292     return TropicalWeightTpl<T>(f1 - f2);
    293 }
    294 
    295 inline TropicalWeightTpl<float> Divide(const TropicalWeightTpl<float> &w1,
    296                                        const TropicalWeightTpl<float> &w2,
    297                                        DivideType typ = DIVIDE_ANY) {
    298   return Divide<float>(w1, w2, typ);
    299 }
    300 
    301 inline TropicalWeightTpl<double> Divide(const TropicalWeightTpl<double> &w1,
    302                                         const TropicalWeightTpl<double> &w2,
    303                                         DivideType typ = DIVIDE_ANY) {
    304   return Divide<double>(w1, w2, typ);
    305 }
    306 
    307 
    308 // Log semiring: (log(e^-x + e^y), +, inf, 0)
    309 template <class T>
    310 class LogWeightTpl : public FloatWeightTpl<T> {
    311  public:
    312   using FloatWeightTpl<T>::Value;
    313 
    314   typedef LogWeightTpl ReverseWeight;
    315 
    316   LogWeightTpl() : FloatWeightTpl<T>() {}
    317 
    318   LogWeightTpl(T f) : FloatWeightTpl<T>(f) {}
    319 
    320   LogWeightTpl(const LogWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}
    321 
    322   static const LogWeightTpl<T> Zero() {
    323     return LogWeightTpl<T>(FloatLimits<T>::kPosInfinity);
    324   }
    325 
    326   static const LogWeightTpl<T> One() {
    327     return LogWeightTpl<T>(0.0F);
    328   }
    329 
    330   static const LogWeightTpl<T> NoWeight() {
    331     return LogWeightTpl<T>(FloatLimits<T>::kNumberBad); }
    332 
    333   static const string &Type() {
    334     static const string type = "log" + FloatWeightTpl<T>::GetPrecisionString();
    335     return type;
    336   }
    337 
    338   bool Member() const {
    339     // First part fails for IEEE NaN
    340     return Value() == Value() && Value() != FloatLimits<T>::kNegInfinity;
    341   }
    342 
    343   LogWeightTpl<T> Quantize(float delta = kDelta) const {
    344     if (Value() == FloatLimits<T>::kNegInfinity ||
    345         Value() == FloatLimits<T>::kPosInfinity ||
    346         Value() != Value())
    347       return *this;
    348     else
    349       return LogWeightTpl<T>(floor(Value()/delta + 0.5F) * delta);
    350   }
    351 
    352   LogWeightTpl<T> Reverse() const { return *this; }
    353 
    354   static uint64 Properties() {
    355     return kLeftSemiring | kRightSemiring | kCommutative;
    356   }
    357 };
    358 
    359 // Single-precision log weight
    360 typedef LogWeightTpl<float> LogWeight;
    361 // Double-precision log weight
    362 typedef LogWeightTpl<double> Log64Weight;
    363 
    364 template <class T>
    365 inline T LogExp(T x) { return log(1.0F + exp(-x)); }
    366 
    367 template <class T>
    368 inline LogWeightTpl<T> Plus(const LogWeightTpl<T> &w1,
    369                             const LogWeightTpl<T> &w2) {
    370   T f1 = w1.Value(), f2 = w2.Value();
    371   if (f1 == FloatLimits<T>::kPosInfinity)
    372     return w2;
    373   else if (f2 == FloatLimits<T>::kPosInfinity)
    374     return w1;
    375   else if (f1 > f2)
    376     return LogWeightTpl<T>(f2 - LogExp(f1 - f2));
    377   else
    378     return LogWeightTpl<T>(f1 - LogExp(f2 - f1));
    379 }
    380 
    381 inline LogWeightTpl<float> Plus(const LogWeightTpl<float> &w1,
    382                                 const LogWeightTpl<float> &w2) {
    383   return Plus<float>(w1, w2);
    384 }
    385 
    386 inline LogWeightTpl<double> Plus(const LogWeightTpl<double> &w1,
    387                                  const LogWeightTpl<double> &w2) {
    388   return Plus<double>(w1, w2);
    389 }
    390 
    391 template <class T>
    392 inline LogWeightTpl<T> Times(const LogWeightTpl<T> &w1,
    393                              const LogWeightTpl<T> &w2) {
    394   if (!w1.Member() || !w2.Member())
    395     return LogWeightTpl<T>::NoWeight();
    396   T f1 = w1.Value(), f2 = w2.Value();
    397   if (f1 == FloatLimits<T>::kPosInfinity)
    398     return w1;
    399   else if (f2 == FloatLimits<T>::kPosInfinity)
    400     return w2;
    401   else
    402     return LogWeightTpl<T>(f1 + f2);
    403 }
    404 
    405 inline LogWeightTpl<float> Times(const LogWeightTpl<float> &w1,
    406                                  const LogWeightTpl<float> &w2) {
    407   return Times<float>(w1, w2);
    408 }
    409 
    410 inline LogWeightTpl<double> Times(const LogWeightTpl<double> &w1,
    411                                   const LogWeightTpl<double> &w2) {
    412   return Times<double>(w1, w2);
    413 }
    414 
    415 template <class T>
    416 inline LogWeightTpl<T> Divide(const LogWeightTpl<T> &w1,
    417                               const LogWeightTpl<T> &w2,
    418                               DivideType typ = DIVIDE_ANY) {
    419   if (!w1.Member() || !w2.Member())
    420     return LogWeightTpl<T>::NoWeight();
    421   T f1 = w1.Value(), f2 = w2.Value();
    422   if (f2 == FloatLimits<T>::kPosInfinity)
    423     return FloatLimits<T>::kNumberBad;
    424   else if (f1 == FloatLimits<T>::kPosInfinity)
    425     return FloatLimits<T>::kPosInfinity;
    426   else
    427     return LogWeightTpl<T>(f1 - f2);
    428 }
    429 
    430 inline LogWeightTpl<float> Divide(const LogWeightTpl<float> &w1,
    431                                   const LogWeightTpl<float> &w2,
    432                                   DivideType typ = DIVIDE_ANY) {
    433   return Divide<float>(w1, w2, typ);
    434 }
    435 
    436 inline LogWeightTpl<double> Divide(const LogWeightTpl<double> &w1,
    437                                    const LogWeightTpl<double> &w2,
    438                                    DivideType typ = DIVIDE_ANY) {
    439   return Divide<double>(w1, w2, typ);
    440 }
    441 
    442 // MinMax semiring: (min, max, inf, -inf)
    443 template <class T>
    444 class MinMaxWeightTpl : public FloatWeightTpl<T> {
    445  public:
    446   using FloatWeightTpl<T>::Value;
    447 
    448   typedef MinMaxWeightTpl<T> ReverseWeight;
    449 
    450   MinMaxWeightTpl() : FloatWeightTpl<T>() {}
    451 
    452   MinMaxWeightTpl(T f) : FloatWeightTpl<T>(f) {}
    453 
    454   MinMaxWeightTpl(const MinMaxWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}
    455 
    456   static const MinMaxWeightTpl<T> Zero() {
    457     return MinMaxWeightTpl<T>(FloatLimits<T>::kPosInfinity);
    458   }
    459 
    460   static const MinMaxWeightTpl<T> One() {
    461     return MinMaxWeightTpl<T>(FloatLimits<T>::kNegInfinity);
    462   }
    463 
    464   static const MinMaxWeightTpl<T> NoWeight() {
    465     return MinMaxWeightTpl<T>(FloatLimits<T>::kNumberBad); }
    466 
    467   static const string &Type() {
    468     static const string type = "minmax" +
    469         FloatWeightTpl<T>::GetPrecisionString();
    470     return type;
    471   }
    472 
    473   bool Member() const {
    474     // Fails for IEEE NaN
    475     return Value() == Value();
    476   }
    477 
    478   MinMaxWeightTpl<T> Quantize(float delta = kDelta) const {
    479     // If one of infinities, or a NaN
    480     if (Value() == FloatLimits<T>::kNegInfinity ||
    481         Value() == FloatLimits<T>::kPosInfinity ||
    482         Value() != Value())
    483       return *this;
    484     else
    485       return MinMaxWeightTpl<T>(floor(Value()/delta + 0.5F) * delta);
    486   }
    487 
    488   MinMaxWeightTpl<T> Reverse() const { return *this; }
    489 
    490   static uint64 Properties() {
    491     return kLeftSemiring | kRightSemiring | kCommutative | kIdempotent | kPath;
    492   }
    493 };
    494 
    495 // Single-precision min-max weight
    496 typedef MinMaxWeightTpl<float> MinMaxWeight;
    497 
    498 // Min
    499 template <class T>
    500 inline MinMaxWeightTpl<T> Plus(
    501     const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) {
    502   if (!w1.Member() || !w2.Member())
    503     return MinMaxWeightTpl<T>::NoWeight();
    504   return w1.Value() < w2.Value() ? w1 : w2;
    505 }
    506 
    507 inline MinMaxWeightTpl<float> Plus(
    508     const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) {
    509   return Plus<float>(w1, w2);
    510 }
    511 
    512 inline MinMaxWeightTpl<double> Plus(
    513     const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) {
    514   return Plus<double>(w1, w2);
    515 }
    516 
    517 // Max
    518 template <class T>
    519 inline MinMaxWeightTpl<T> Times(
    520     const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) {
    521   if (!w1.Member() || !w2.Member())
    522     return MinMaxWeightTpl<T>::NoWeight();
    523   return w1.Value() >= w2.Value() ? w1 : w2;
    524 }
    525 
    526 inline MinMaxWeightTpl<float> Times(
    527     const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) {
    528   return Times<float>(w1, w2);
    529 }
    530 
    531 inline MinMaxWeightTpl<double> Times(
    532     const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) {
    533   return Times<double>(w1, w2);
    534 }
    535 
    536 // Defined only for special cases
    537 template <class T>
    538 inline MinMaxWeightTpl<T> Divide(const MinMaxWeightTpl<T> &w1,
    539                                  const MinMaxWeightTpl<T> &w2,
    540                                  DivideType typ = DIVIDE_ANY) {
    541   if (!w1.Member() || !w2.Member())
    542     return MinMaxWeightTpl<T>::NoWeight();
    543   // min(w1, x) = w2, w1 >= w2 => min(w1, x) = w2, x = w2
    544   return w1.Value() >= w2.Value() ? w1 : FloatLimits<T>::kNumberBad;
    545 }
    546 
    547 inline MinMaxWeightTpl<float> Divide(const MinMaxWeightTpl<float> &w1,
    548                                      const MinMaxWeightTpl<float> &w2,
    549                                      DivideType typ = DIVIDE_ANY) {
    550   return Divide<float>(w1, w2, typ);
    551 }
    552 
    553 inline MinMaxWeightTpl<double> Divide(const MinMaxWeightTpl<double> &w1,
    554                                       const MinMaxWeightTpl<double> &w2,
    555                                       DivideType typ = DIVIDE_ANY) {
    556   return Divide<double>(w1, w2, typ);
    557 }
    558 
    559 //
    560 // WEIGHT CONVERTER SPECIALIZATIONS.
    561 //
    562 
    563 // Convert to tropical
    564 template <>
    565 struct WeightConvert<LogWeight, TropicalWeight> {
    566   TropicalWeight operator()(LogWeight w) const { return w.Value(); }
    567 };
    568 
    569 template <>
    570 struct WeightConvert<Log64Weight, TropicalWeight> {
    571   TropicalWeight operator()(Log64Weight w) const { return w.Value(); }
    572 };
    573 
    574 // Convert to log
    575 template <>
    576 struct WeightConvert<TropicalWeight, LogWeight> {
    577   LogWeight operator()(TropicalWeight w) const { return w.Value(); }
    578 };
    579 
    580 template <>
    581 struct WeightConvert<Log64Weight, LogWeight> {
    582   LogWeight operator()(Log64Weight w) const { return w.Value(); }
    583 };
    584 
    585 // Convert to log64
    586 template <>
    587 struct WeightConvert<TropicalWeight, Log64Weight> {
    588   Log64Weight operator()(TropicalWeight w) const { return w.Value(); }
    589 };
    590 
    591 template <>
    592 struct WeightConvert<LogWeight, Log64Weight> {
    593   Log64Weight operator()(LogWeight w) const { return w.Value(); }
    594 };
    595 
    596 }  // namespace fst
    597 
    598 #endif  // FST_LIB_FLOAT_WEIGHT_H__
    599