Home | History | Annotate | Download | only in lib
      1 // float-weight.h
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 //
     16 // \file
     17 // Float weight set and associated semiring operation definitions.
     18 //
     19 
     20 #ifndef FST_LIB_FLOAT_WEIGHT_H__
     21 #define FST_LIB_FLOAT_WEIGHT_H__
     22 
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <limits>

#include "fst/lib/weight.h"
     26 
     27 namespace fst {
     28 
     29 static const float kPosInfinity = numeric_limits<float>::infinity();
     30 static const float kNegInfinity = -kPosInfinity;
     31 
     32 // Single precision floating point weight base class
     33 class FloatWeight {
     34  public:
     35   FloatWeight() {}
     36 
     37   FloatWeight(float f) : value_(f) {}
     38 
     39   FloatWeight(const FloatWeight &w) : value_(w.value_) {}
     40 
     41   FloatWeight &operator=(const FloatWeight &w) {
     42     value_ = w.value_;
     43     return *this;
     44   }
     45 
     46   istream &Read(istream &strm) {
     47     return ReadType(strm, &value_);
     48   }
     49 
     50   ostream &Write(ostream &strm) const {
     51     return WriteType(strm, value_);
     52   }
     53 
     54   ssize_t Hash() const {
     55     union {
     56       float f;
     57       ssize_t s;
     58     } u = { value_ };
     59     return u.s;
     60   }
     61 
     62   const float &Value() const { return value_; }
     63 
     64  protected:
     65   float value_;
     66 };
     67 
     68 inline bool operator==(const FloatWeight &w1, const FloatWeight &w2) {
     69   // Volatile qualifier thwarts over-aggressive compiler optimizations
     70   // that lead to problems esp. with NaturalLess().
     71   volatile float v1 = w1.Value();
     72   volatile float v2 = w2.Value();
     73   return v1 == v2;
     74 }
     75 
     76 inline bool operator!=(const FloatWeight &w1, const FloatWeight &w2) {
     77   return !(w1 == w2);
     78 }
     79 
     80 inline bool ApproxEqual(const FloatWeight &w1, const FloatWeight &w2,
     81                         float delta = kDelta) {
     82   return w1.Value() <= w2.Value() + delta && w2.Value() <= w1.Value() + delta;
     83 }
     84 
     85 inline ostream &operator<<(ostream &strm, const FloatWeight &w) {
     86   if (w.Value() == kPosInfinity)
     87     return strm << "Infinity";
     88   else if (w.Value() == kNegInfinity)
     89     return strm << "-Infinity";
     90   else if (w.Value() != w.Value())   // Fails for NaN
     91     return strm << "BadFloat";
     92   else
     93     return strm << w.Value();
     94 }
     95 
     96 inline istream &operator>>(istream &strm, FloatWeight &w) {
     97   string s;
     98   strm >> s;
     99   if (s == "Infinity") {
    100     w = FloatWeight(kPosInfinity);
    101   } else if (s == "-Infinity") {
    102     w = FloatWeight(kNegInfinity);
    103   } else {
    104     char *p;
    105     float f = strtod(s.c_str(), &p);
    106     if (p < s.c_str() + s.size())
    107       strm.clear(std::ios::badbit);
    108     else
    109       w = FloatWeight(f);
    110   }
    111   return strm;
    112 }
    113 
    114 
    115 // Tropical semiring: (min, +, inf, 0)
    116 class TropicalWeight : public FloatWeight {
    117  public:
    118   typedef TropicalWeight ReverseWeight;
    119 
    120   TropicalWeight() : FloatWeight() {}
    121 
    122   TropicalWeight(float f) : FloatWeight(f) {}
    123 
    124   TropicalWeight(const TropicalWeight &w) : FloatWeight(w) {}
    125 
    126   static const TropicalWeight Zero() { return TropicalWeight(kPosInfinity); }
    127 
    128   static const TropicalWeight One() { return TropicalWeight(0.0F); }
    129 
    130   static const string &Type() {
    131     static const string type = "tropical";
    132     return type;
    133   }
    134 
    135   bool Member() const {
    136     // First part fails for IEEE NaN
    137     return Value() == Value() && Value() != kNegInfinity;
    138   }
    139 
    140   TropicalWeight Quantize(float delta = kDelta) const {
    141     return TropicalWeight(floor(Value()/delta + 0.5F) * delta);
    142   }
    143 
    144   TropicalWeight Reverse() const { return *this; }
    145 
    146   static uint64 Properties() {
    147     return kLeftSemiring | kRightSemiring | kCommutative |
    148       kPath | kIdempotent;
    149   }
    150 };
    151 
    152 inline TropicalWeight Plus(const TropicalWeight &w1,
    153                            const TropicalWeight &w2) {
    154   return w1.Value() < w2.Value() ? w1 : w2;
    155 }
    156 
    157 inline TropicalWeight Times(const TropicalWeight &w1,
    158                             const TropicalWeight &w2) {
    159   float f1 = w1.Value(), f2 = w2.Value();
    160   if (f1 == kPosInfinity)
    161     return w1;
    162   else if (f2 == kPosInfinity)
    163     return w2;
    164   else
    165     return TropicalWeight(f1 + f2);
    166 }
    167 
    168 inline TropicalWeight Divide(const TropicalWeight &w1,
    169                              const TropicalWeight &w2,
    170                              DivideType typ = DIVIDE_ANY) {
    171   float f1 = w1.Value(), f2 = w2.Value();
    172   if (f2 == kPosInfinity)
    173     return kNegInfinity;
    174   else if (f1 == kPosInfinity)
    175     return kPosInfinity;
    176   else
    177     return TropicalWeight(f1 - f2);
    178 }
    179 
    180 
    181 // Log semiring: (log(e^-x + e^y), +, inf, 0)
    182 class LogWeight : public FloatWeight {
    183  public:
    184   typedef LogWeight ReverseWeight;
    185 
    186   LogWeight() : FloatWeight() {}
    187 
    188   LogWeight(float f) : FloatWeight(f) {}
    189 
    190   LogWeight(const LogWeight &w) : FloatWeight(w) {}
    191 
    192   static const LogWeight Zero() {   return LogWeight(kPosInfinity); }
    193 
    194   static const LogWeight One() { return LogWeight(0.0F); }
    195 
    196   static const string &Type() {
    197     static const string type = "log";
    198     return type;
    199   }
    200 
    201   bool Member() const {
    202     // First part fails for IEEE NaN
    203     return Value() == Value() && Value() != kNegInfinity;
    204   }
    205 
    206   LogWeight Quantize(float delta = kDelta) const {
    207     return LogWeight(floor(Value()/delta + 0.5F) * delta);
    208   }
    209 
    210   LogWeight Reverse() const { return *this; }
    211 
    212   static uint64 Properties() {
    213     return kLeftSemiring | kRightSemiring | kCommutative;
    214   }
    215 };
    216 
    217 inline double LogExp(double x) { return log(1.0F + exp(-x)); }
    218 
    219 inline LogWeight Plus(const LogWeight &w1, const LogWeight &w2) {
    220   float f1 = w1.Value(), f2 = w2.Value();
    221   if (f1 == kPosInfinity)
    222     return w2;
    223   else if (f2 == kPosInfinity)
    224     return w1;
    225   else if (f1 > f2)
    226     return LogWeight(f2 - LogExp(f1 - f2));
    227   else
    228     return LogWeight(f1 - LogExp(f2 - f1));
    229 }
    230 
    231 inline LogWeight Times(const LogWeight &w1, const LogWeight &w2) {
    232   float f1 = w1.Value(), f2 = w2.Value();
    233   if (f1 == kPosInfinity)
    234     return w1;
    235   else if (f2 == kPosInfinity)
    236     return w2;
    237   else
    238     return LogWeight(f1 + f2);
    239 }
    240 
    241 inline LogWeight Divide(const LogWeight &w1,
    242                              const LogWeight &w2,
    243                              DivideType typ = DIVIDE_ANY) {
    244   float f1 = w1.Value(), f2 = w2.Value();
    245   if (f2 == kPosInfinity)
    246     return kNegInfinity;
    247   else if (f1 == kPosInfinity)
    248     return kPosInfinity;
    249   else
    250     return LogWeight(f1 - f2);
    251 }
    252 
    253 }  // namespace fst;
    254 
    255 #endif  // FST_LIB_FLOAT_WEIGHT_H__
    256