// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Contains a Huffman codec: building a code from a histogram or from a
// serialized tree, encoding and decoding values, serializing the tree to text
// and debug printing the tree and the code table.

#ifndef LIBSPIRV_UTIL_HUFFMAN_CODEC_H_
#define LIBSPIRV_UTIL_HUFFMAN_CODEC_H_

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iomanip>
#include <map>
#include <memory>
#include <ostream>
#include <queue>
#include <sstream>
#include <stack>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

namespace spvutils {

// Used to generate and apply a Huffman coding scheme.
// |Val| is the type of variable being encoded (for example a string or a
// literal). It is used as a key of std::map (histogram) and
// std::unordered_map (encoding table); the serialization and printing
// functions additionally write it with operator<<.
template <class Val>
class HuffmanCodec {
 public:
  // Huffman tree node.
  struct Node {
    Node() {}

    // Creates Node from serialization. Weight and id are not part of the
    // serialized form and keep their default value of 0.
    Node(const Val& in_value, uint32_t in_left, uint32_t in_right)
        : value(in_value), left(in_left), right(in_right) {}

    Val value = Val();
    // Sum of the histogram counts of all leaves in this subtree. Stored as 64
    // bits because the sum of many uint32_t counts can exceed 32 bits, and a
    // wrapped-around weight would break the ordering used to build the tree.
    uint64_t weight = 0;
    // Ids are issued sequentially starting from 1. Ids are used as an ordering
    // tie-breaker, to make sure that the ordering (and resulting coding scheme)
    // is consistent across multiple platforms.
    uint32_t id = 0;
    // Handles of children (0, the NIL node, means "no child").
    uint32_t left = 0;
    uint32_t right = 0;
  };

  // Creates Huffman codec from a histogram.
  // Histogram counts must not be zero. An empty histogram produces an empty
  // codec: nothing can be encoded and decoding always fails.
  explicit HuffmanCodec(const std::map<Val, uint32_t>& hist) {
    if (hist.empty()) return;

    // Exact node count: NIL + n leaves + (n - 1) internal nodes.
    nodes_.reserve(2 * hist.size());

    // Create NIL (handle 0).
    CreateNode();

    // The queue is sorted in ascending order by weight (or by node id if
    // weights are equal): its top is always the lightest node.
    auto heavier = [this](uint32_t lhs, uint32_t rhs) {
      return LeftIsBigger(lhs, rhs);
    };
    std::vector<uint32_t> queue_vector;
    queue_vector.reserve(hist.size());
    std::priority_queue<uint32_t, std::vector<uint32_t>, decltype(heavier)>
        queue(heavier, std::move(queue_vector));

    // Put all leaves in the queue.
    for (const auto& pair : hist) {
      const uint32_t node = CreateNode();
      MutableValueOf(node) = pair.first;
      MutableWeightOf(node) = pair.second;
      assert(WeightOf(node));
      queue.push(node);
    }

    // Form the tree by combining two subtrees with the least weight,
    // and pushing the root of the new tree in the queue.
    while (true) {
      // We push a node at the end of each iteration, so the queue is never
      // supposed to be empty at this point, unless there are no leaves, but
      // that case was already handled.
      assert(!queue.empty());
      const uint32_t right = queue.top();
      queue.pop();

      // If the queue is empty at this point, then the last node is
      // the root of the complete Huffman tree.
      if (queue.empty()) {
        root_ = right;
        break;
      }

      const uint32_t left = queue.top();
      queue.pop();

      // Combine left and right into a new tree and push it into the queue.
      // 64-bit weights keep this sum from wrapping around.
      const uint32_t parent = CreateNode();
      MutableWeightOf(parent) = WeightOf(right) + WeightOf(left);
      MutableLeftOf(parent) = left;
      MutableRightOf(parent) = right;
      queue.push(parent);
    }

    // Traverse the tree and form encoding table.
    CreateEncodingTable();
  }

  // Creates Huffman codec from saved tree structure.
  // |nodes| is the list of nodes of the tree, nodes[0] being NIL.
  // |root_handle| is the index of the root node.
  HuffmanCodec(uint32_t root_handle, std::vector<Node>&& nodes)
      : root_(root_handle), nodes_(std::move(nodes)) {
    assert(!nodes_.empty());
    assert(root_handle > 0 && root_handle < nodes_.size());
    assert(!LeftOf(0) && !RightOf(0));

    // Traverse the tree and form encoding table.
    CreateEncodingTable();
  }

  // Serializes the codec in the following text format:
  // (<root_handle>, {
  //   {0, 0, 0},
  //   {val1, left1, right1},
  //   {val2, left2, right2},
  //   ...
  // })
  // If |Val| is std::string, values are wrapped in double quotes (no
  // escaping is performed).
  std::string SerializeToText(int indent_num_whitespaces) const {
    const bool value_is_text = std::is_same<Val, std::string>::value;

    const std::string indent1 = std::string(indent_num_whitespaces, ' ');
    const std::string indent2 = std::string(indent_num_whitespaces + 2, ' ');

    std::stringstream code;
    code << "(" << root_ << ", {\n";

    for (const Node& node : nodes_) {
      code << indent2 << "{";
      if (value_is_text) code << "\"";
      code << node.value;
      if (value_is_text) code << "\"";
      code << ", " << node.left << ", " << node.right << "},\n";
    }

    code << indent1 << "})";

    return code.str();
  }

  // Prints the Huffman tree in the following format:
  // w------w------'x'
  //        w------'y'
  // Where w stands for the weight of the node.
  // Right tree branches appear above left branches. Taking the right path
  // adds 1 to the code, taking the left adds 0.
  void PrintTree(std::ostream& out) const {
    PrintTreeInternal(out, root_, 0);
  }

  // Traverses the tree breadth-first and prints the Huffman table: value,
  // optionally node weight, and code for every leaf, one leaf per line.
  // Prints nothing for an empty codec.
  void PrintTable(std::ostream& out, bool print_weights = true) const {
    if (!root_) return;

    std::queue<std::pair<uint32_t, std::string>> queue;
    queue.emplace(root_, "");

    while (!queue.empty()) {
      const uint32_t node = queue.front().first;
      const std::string code = std::move(queue.front().second);
      queue.pop();
      if (!RightOf(node) && !LeftOf(node)) {
        out << ValueOf(node);
        if (print_weights) out << " " << WeightOf(node);
        out << " " << code << '\n';
      } else {
        if (LeftOf(node)) queue.emplace(LeftOf(node), code + "0");

        if (RightOf(node)) queue.emplace(RightOf(node), code + "1");
      }
    }
  }

  // Returns the Huffman table. The table was built at construction time,
  // this function just returns a const reference.
  const std::unordered_map<Val, std::pair<uint64_t, size_t>>&
  GetEncodingTable() const {
    return encoding_table_;
  }

  // Encodes |val| and stores its Huffman code in the lower |num_bits| of
  // |bits| (the first bit of the code is the least significant one).
  // Returns false if |val| is not in the Huffman table.
  bool Encode(const Val& val, uint64_t* bits, size_t* num_bits) const {
    const auto it = encoding_table_.find(val);
    if (it == encoding_table_.end()) return false;
    *bits = it->second.first;
    *num_bits = it->second.second;
    return true;
  }

  // Reads bits one-by-one using callback |read_bit| until a match is found.
  // Matching value is stored in |val|. Returns false if |read_bit| terminates
  // before a code was matched, or if the codec is empty.
  // |read_bit| has type bool func(bool* bit). When called, the next bit is
  // stored in |bit|. |read_bit| returns false if the stream terminates
  // prematurely.
  bool DecodeFromStream(const std::function<bool(bool*)>& read_bit,
                        Val* val) const {
    // An empty codec has no tree, so no code can ever match.
    if (!root_) return false;

    uint32_t node = root_;
    while (true) {
      assert(node);

      if (!RightOf(node) && !LeftOf(node)) {
        *val = ValueOf(node);
        return true;
      }

      bool go_right = false;
      if (!read_bit(&go_right)) return false;

      node = go_right ? RightOf(node) : LeftOf(node);
    }

    assert(0);
    return false;
  }

 private:
  // Returns value of |node|.
  Val ValueOf(uint32_t node) const { return nodes_.at(node).value; }

  // Returns left child of |node|.
  uint32_t LeftOf(uint32_t node) const { return nodes_.at(node).left; }

  // Returns right child of |node|.
  uint32_t RightOf(uint32_t node) const { return nodes_.at(node).right; }

  // Returns weight of |node|.
  uint64_t WeightOf(uint32_t node) const { return nodes_.at(node).weight; }

  // Returns id of |node|.
  uint32_t IdOf(uint32_t node) const { return nodes_.at(node).id; }

  // Returns mutable reference to value of |node|.
  Val& MutableValueOf(uint32_t node) {
    assert(node);
    return nodes_.at(node).value;
  }

  // Returns mutable reference to handle of left child of |node|.
  uint32_t& MutableLeftOf(uint32_t node) {
    assert(node);
    return nodes_.at(node).left;
  }

  // Returns mutable reference to handle of right child of |node|.
  uint32_t& MutableRightOf(uint32_t node) {
    assert(node);
    return nodes_.at(node).right;
  }

  // Returns mutable reference to weight of |node|.
  uint64_t& MutableWeightOf(uint32_t node) { return nodes_.at(node).weight; }

  // Returns mutable reference to id of |node|.
  uint32_t& MutableIdOf(uint32_t node) { return nodes_.at(node).id; }

  // Returns true if |left| has bigger weight than |right|. Node ids are
  // used as tie-breaker.
  bool LeftIsBigger(uint32_t left, uint32_t right) const {
    if (WeightOf(left) == WeightOf(right)) {
      assert(IdOf(left) != IdOf(right));
      return IdOf(left) > IdOf(right);
    }
    return WeightOf(left) > WeightOf(right);
  }

  // Prints subtree (helper function used by PrintTree).
  void PrintTreeInternal(std::ostream& out, uint32_t node, size_t depth) const {
    if (!node) return;

    const size_t kTextFieldWidth = 7;

    if (!RightOf(node) && !LeftOf(node)) {
      out << ValueOf(node) << '\n';
    } else {
      if (RightOf(node)) {
        std::stringstream label;
        label << std::setfill('-') << std::left << std::setw(kTextFieldWidth)
              << WeightOf(RightOf(node));
        out << label.str();
        PrintTreeInternal(out, RightOf(node), depth + 1);
      }

      if (LeftOf(node)) {
        out << std::string(depth * kTextFieldWidth, ' ');
        std::stringstream label;
        label << std::setfill('-') << std::left << std::setw(kTextFieldWidth)
              << WeightOf(LeftOf(node));
        out << label.str();
        PrintTreeInternal(out, LeftOf(node), depth + 1);
      }
    }
  }

  // Traverses the Huffman tree and saves paths to the leaves as bit
  // sequences to encoding_table_.
  void CreateEncodingTable() {
    struct Context {
      Context(uint32_t in_node, uint64_t in_bits, size_t in_depth)
          : node(in_node), bits(in_bits), depth(in_depth) {}
      uint32_t node;
      // Path from the root so far; the first step is the least significant
      // bit. Codes are held in 64 bits, so the tree depth must not exceed 64
      // (asserted below before any child deeper than 64 could be formed).
      uint64_t bits;
      size_t depth;
    };

    std::queue<Context> queue;
    queue.emplace(root_, 0, 0);

    while (!queue.empty()) {
      const Context context = queue.front();
      queue.pop();

      if (!RightOf(context.node) && !LeftOf(context.node)) {
        auto insertion_result = encoding_table_.emplace(
            ValueOf(context.node),
            std::pair<uint64_t, size_t>(context.bits, context.depth));
        assert(insertion_result.second);
        (void)insertion_result;
      } else {
        // Children get depth + 1 <= 64; shifting by >= 64 would be undefined.
        assert(context.depth < 64);

        if (LeftOf(context.node))
          queue.emplace(LeftOf(context.node), context.bits, context.depth + 1);

        if (RightOf(context.node))
          queue.emplace(RightOf(context.node),
                        context.bits | (uint64_t{1} << context.depth),
                        context.depth + 1);
      }
    }
  }

  // Creates a new Huffman tree node, appends it to nodes_ and returns its
  // handle (index into nodes_).
  uint32_t CreateNode() {
    const uint32_t handle = static_cast<uint32_t>(nodes_.size());
    nodes_.emplace_back();
    nodes_.back().id = next_node_id_++;
    return handle;
  }

  // Huffman tree root handle (0 for an empty codec).
  uint32_t root_ = 0;

  // Storage for all tree nodes, indexed by handle; nodes_[0] is NIL.
  std::vector<Node> nodes_;

  // Encoding table value -> {bits, num_bits}.
  // Huffman codes must not exceed 64 bits (asserted in CreateEncodingTable).
  std::unordered_map<Val, std::pair<uint64_t, size_t>> encoding_table_;

  // Next node id issued by CreateNode().
  uint32_t next_node_id_ = 1;
};

}  // namespace spvutils

#endif  // LIBSPIRV_UTIL_HUFFMAN_CODEC_H_