1 // float-weight.h 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // 16 // \file 17 // Float weight set and associated semiring operation definitions. 18 // 19 20 #ifndef FST_LIB_FLOAT_WEIGHT_H__ 21 #define FST_LIB_FLOAT_WEIGHT_H__ 22 23 #include <limits> 24 25 #include "fst/lib/weight.h" 26 27 namespace fst { 28 29 static const float kPosInfinity = numeric_limits<float>::infinity(); 30 static const float kNegInfinity = -kPosInfinity; 31 32 // Single precision floating point weight base class 33 class FloatWeight { 34 public: 35 FloatWeight() {} 36 37 FloatWeight(float f) : value_(f) {} 38 39 FloatWeight(const FloatWeight &w) : value_(w.value_) {} 40 41 FloatWeight &operator=(const FloatWeight &w) { 42 value_ = w.value_; 43 return *this; 44 } 45 46 istream &Read(istream &strm) { 47 return ReadType(strm, &value_); 48 } 49 50 ostream &Write(ostream &strm) const { 51 return WriteType(strm, value_); 52 } 53 54 ssize_t Hash() const { 55 union { 56 float f; 57 ssize_t s; 58 } u = { value_ }; 59 return u.s; 60 } 61 62 const float &Value() const { return value_; } 63 64 protected: 65 float value_; 66 }; 67 68 inline bool operator==(const FloatWeight &w1, const FloatWeight &w2) { 69 // Volatile qualifier thwarts over-aggressive compiler optimizations 70 // that lead to problems esp. with NaturalLess(). 71 volatile float v1 = w1.Value(); 72 volatile float v2 = w2.Value(); 73 return v1 == v2; 74 } 75 76 inline bool operator!=(const FloatWeight &w1, const FloatWeight &w2) { 77 return !(w1 == w2); 78 } 79 80 inline bool ApproxEqual(const FloatWeight &w1, const FloatWeight &w2, 81 float delta = kDelta) { 82 return w1.Value() <= w2.Value() + delta && w2.Value() <= w1.Value() + delta; 83 } 84 85 inline ostream &operator<<(ostream &strm, const FloatWeight &w) { 86 if (w.Value() == kPosInfinity) 87 return strm << "Infinity"; 88 else if (w.Value() == kNegInfinity) 89 return strm << "-Infinity"; 90 else if (w.Value() != w.Value()) // Fails for NaN 91 return strm << "BadFloat"; 92 else 93 return strm << w.Value(); 94 } 95 96 inline istream &operator>>(istream &strm, FloatWeight &w) { 97 string s; 98 strm >> s; 99 if (s == "Infinity") { 100 w = FloatWeight(kPosInfinity); 101 } else if (s == "-Infinity") { 102 w = FloatWeight(kNegInfinity); 103 } else { 104 char *p; 105 float f = strtod(s.c_str(), &p); 106 if (p < s.c_str() + s.size()) 107 strm.clear(std::ios::badbit); 108 else 109 w = FloatWeight(f); 110 } 111 return strm; 112 } 113 114 115 // Tropical semiring: (min, +, inf, 0) 116 class TropicalWeight : public FloatWeight { 117 public: 118 typedef TropicalWeight ReverseWeight; 119 120 TropicalWeight() : FloatWeight() {} 121 122 TropicalWeight(float f) : FloatWeight(f) {} 123 124 TropicalWeight(const TropicalWeight &w) : FloatWeight(w) {} 125 126 static const TropicalWeight Zero() { return TropicalWeight(kPosInfinity); } 127 128 static const TropicalWeight One() { return TropicalWeight(0.0F); } 129 130 static const string &Type() { 131 static const string type = "tropical"; 132 return type; 133 } 134 135 bool Member() const { 136 // First part fails for IEEE NaN 137 return Value() == Value() && Value() != kNegInfinity; 138 } 139 140 TropicalWeight Quantize(float delta = kDelta) const { 141 return TropicalWeight(floor(Value()/delta + 0.5F) * delta); 142 } 143 144 TropicalWeight Reverse() const { return *this; } 145 146 static uint64 Properties() { 147 return kLeftSemiring | kRightSemiring | kCommutative | 148 kPath | kIdempotent; 149 } 150 }; 151 152 inline TropicalWeight Plus(const TropicalWeight &w1, 153 const TropicalWeight &w2) { 154 return w1.Value() < w2.Value() ? w1 : w2; 155 } 156 157 inline TropicalWeight Times(const TropicalWeight &w1, 158 const TropicalWeight &w2) { 159 float f1 = w1.Value(), f2 = w2.Value(); 160 if (f1 == kPosInfinity) 161 return w1; 162 else if (f2 == kPosInfinity) 163 return w2; 164 else 165 return TropicalWeight(f1 + f2); 166 } 167 168 inline TropicalWeight Divide(const TropicalWeight &w1, 169 const TropicalWeight &w2, 170 DivideType typ = DIVIDE_ANY) { 171 float f1 = w1.Value(), f2 = w2.Value(); 172 if (f2 == kPosInfinity) 173 return kNegInfinity; 174 else if (f1 == kPosInfinity) 175 return kPosInfinity; 176 else 177 return TropicalWeight(f1 - f2); 178 } 179 180 181 // Log semiring: (log(e^-x + e^y), +, inf, 0) 182 class LogWeight : public FloatWeight { 183 public: 184 typedef LogWeight ReverseWeight; 185 186 LogWeight() : FloatWeight() {} 187 188 LogWeight(float f) : FloatWeight(f) {} 189 190 LogWeight(const LogWeight &w) : FloatWeight(w) {} 191 192 static const LogWeight Zero() { return LogWeight(kPosInfinity); } 193 194 static const LogWeight One() { return LogWeight(0.0F); } 195 196 static const string &Type() { 197 static const string type = "log"; 198 return type; 199 } 200 201 bool Member() const { 202 // First part fails for IEEE NaN 203 return Value() == Value() && Value() != kNegInfinity; 204 } 205 206 LogWeight Quantize(float delta = kDelta) const { 207 return LogWeight(floor(Value()/delta + 0.5F) * delta); 208 } 209 210 LogWeight Reverse() const { return *this; } 211 212 static uint64 Properties() { 213 return kLeftSemiring | kRightSemiring | kCommutative; 214 } 215 }; 216 217 inline double LogExp(double x) { return log(1.0F + exp(-x)); } 218 219 inline LogWeight Plus(const LogWeight &w1, const LogWeight &w2) { 220 float f1 = w1.Value(), f2 = w2.Value(); 221 if (f1 == kPosInfinity) 222 return w2; 223 else if (f2 == kPosInfinity) 224 return w1; 225 else if (f1 > f2) 226 return LogWeight(f2 - LogExp(f1 - f2)); 227 else 228 return LogWeight(f1 - LogExp(f2 - f1)); 229 } 230 231 inline LogWeight Times(const LogWeight &w1, const LogWeight &w2) { 232 float f1 = w1.Value(), f2 = w2.Value(); 233 if (f1 == kPosInfinity) 234 return w1; 235 else if (f2 == kPosInfinity) 236 return w2; 237 else 238 return LogWeight(f1 + f2); 239 } 240 241 inline LogWeight Divide(const LogWeight &w1, 242 const LogWeight &w2, 243 DivideType typ = DIVIDE_ANY) { 244 float f1 = w1.Value(), f2 = w2.Value(); 245 if (f2 == kPosInfinity) 246 return kNegInfinity; 247 else if (f1 == kPosInfinity) 248 return kPosInfinity; 249 else 250 return LogWeight(f1 - f2); 251 } 252 253 } // namespace fst; 254 255 #endif // FST_LIB_FLOAT_WEIGHT_H__ 256