Home | History | Annotate | Download | only in comp
      1 // Copyright (c) 2017 Google Inc.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 // Contains a Huffman codec: building, serializing, debug printing and applying Huffman coding schemes.
     16 
     17 #ifndef SOURCE_COMP_HUFFMAN_CODEC_H_
     18 #define SOURCE_COMP_HUFFMAN_CODEC_H_
     19 
     20 #include <algorithm>
     21 #include <cassert>
     22 #include <functional>
     23 #include <iomanip>
     24 #include <map>
     25 #include <memory>
     26 #include <ostream>
     27 #include <queue>
     28 #include <sstream>
     29 #include <stack>
     30 #include <string>
     31 #include <tuple>
     32 #include <unordered_map>
     33 #include <utility>
     34 #include <vector>
     35 
     36 namespace spvtools {
     37 namespace comp {
     38 
     39 // Used to generate and apply a Huffman coding scheme.
     40 // |Val| is the type of variable being encoded (for example a string or a
     41 // literal).
     42 template <class Val>
     43 class HuffmanCodec {
     44  public:
     45   // Huffman tree node.
     46   struct Node {
     47     Node() {}
     48 
     49     // Creates Node from serialization leaving weight and id undefined.
     50     Node(const Val& in_value, uint32_t in_left, uint32_t in_right)
     51         : value(in_value), left(in_left), right(in_right) {}
     52 
     53     Val value = Val();
     54     uint32_t weight = 0;
     55     // Ids are issued sequentially starting from 1. Ids are used as an ordering
     56     // tie-breaker, to make sure that the ordering (and resulting coding scheme)
     57     // are consistent accross multiple platforms.
     58     uint32_t id = 0;
     59     // Handles of children.
     60     uint32_t left = 0;
     61     uint32_t right = 0;
     62   };
     63 
     64   // Creates Huffman codec from a histogramm.
     65   // Histogramm counts must not be zero.
     66   explicit HuffmanCodec(const std::map<Val, uint32_t>& hist) {
     67     if (hist.empty()) return;
     68 
     69     // Heuristic estimate.
     70     nodes_.reserve(3 * hist.size());
     71 
     72     // Create NIL.
     73     CreateNode();
     74 
     75     // The queue is sorted in ascending order by weight (or by node id if
     76     // weights are equal).
     77     std::vector<uint32_t> queue_vector;
     78     queue_vector.reserve(hist.size());
     79     std::priority_queue<uint32_t, std::vector<uint32_t>,
     80                         std::function<bool(uint32_t, uint32_t)>>
     81         queue(std::bind(&HuffmanCodec::LeftIsBigger, this,
     82                         std::placeholders::_1, std::placeholders::_2),
     83               std::move(queue_vector));
     84 
     85     // Put all leaves in the queue.
     86     for (const auto& pair : hist) {
     87       const uint32_t node = CreateNode();
     88       MutableValueOf(node) = pair.first;
     89       MutableWeightOf(node) = pair.second;
     90       assert(WeightOf(node));
     91       queue.push(node);
     92     }
     93 
     94     // Form the tree by combining two subtrees with the least weight,
     95     // and pushing the root of the new tree in the queue.
     96     while (true) {
     97       // We push a node at the end of each iteration, so the queue is never
     98       // supposed to be empty at this point, unless there are no leaves, but
     99       // that case was already handled.
    100       assert(!queue.empty());
    101       const uint32_t right = queue.top();
    102       queue.pop();
    103 
    104       // If the queue is empty at this point, then the last node is
    105       // the root of the complete Huffman tree.
    106       if (queue.empty()) {
    107         root_ = right;
    108         break;
    109       }
    110 
    111       const uint32_t left = queue.top();
    112       queue.pop();
    113 
    114       // Combine left and right into a new tree and push it into the queue.
    115       const uint32_t parent = CreateNode();
    116       MutableWeightOf(parent) = WeightOf(right) + WeightOf(left);
    117       MutableLeftOf(parent) = left;
    118       MutableRightOf(parent) = right;
    119       queue.push(parent);
    120     }
    121 
    122     // Traverse the tree and form encoding table.
    123     CreateEncodingTable();
    124   }
    125 
    126   // Creates Huffman codec from saved tree structure.
    127   // |nodes| is the list of nodes of the tree, nodes[0] being NIL.
    128   // |root_handle| is the index of the root node.
    129   HuffmanCodec(uint32_t root_handle, std::vector<Node>&& nodes) {
    130     nodes_ = std::move(nodes);
    131     assert(!nodes_.empty());
    132     assert(root_handle > 0 && root_handle < nodes_.size());
    133     assert(!LeftOf(0) && !RightOf(0));
    134 
    135     root_ = root_handle;
    136 
    137     // Traverse the tree and form encoding table.
    138     CreateEncodingTable();
    139   }
    140 
    141   // Serializes the codec in the following text format:
    142   // (<root_handle>, {
    143   //   {0, 0, 0},
    144   //   {val1, left1, right1},
    145   //   {val2, left2, right2},
    146   //   ...
    147   // })
    148   std::string SerializeToText(int indent_num_whitespaces) const {
    149     const bool value_is_text = std::is_same<Val, std::string>::value;
    150 
    151     const std::string indent1 = std::string(indent_num_whitespaces, ' ');
    152     const std::string indent2 = std::string(indent_num_whitespaces + 2, ' ');
    153 
    154     std::stringstream code;
    155     code << "(" << root_ << ", {\n";
    156 
    157     for (const Node& node : nodes_) {
    158       code << indent2 << "{";
    159       if (value_is_text) code << "\"";
    160       code << node.value;
    161       if (value_is_text) code << "\"";
    162       code << ", " << node.left << ", " << node.right << "},\n";
    163     }
    164 
    165     code << indent1 << "})";
    166 
    167     return code.str();
    168   }
    169 
    170   // Prints the Huffman tree in the following format:
    171   // w------w------'x'
    172   //        w------'y'
    173   // Where w stands for the weight of the node.
    174   // Right tree branches appear above left branches. Taking the right path
    175   // adds 1 to the code, taking the left adds 0.
    176   void PrintTree(std::ostream& out) const { PrintTreeInternal(out, root_, 0); }
    177 
    178   // Traverses the tree and prints the Huffman table: value, code
    179   // and optionally node weight for every leaf.
    180   void PrintTable(std::ostream& out, bool print_weights = true) {
    181     std::queue<std::pair<uint32_t, std::string>> queue;
    182     queue.emplace(root_, "");
    183 
    184     while (!queue.empty()) {
    185       const uint32_t node = queue.front().first;
    186       const std::string code = queue.front().second;
    187       queue.pop();
    188       if (!RightOf(node) && !LeftOf(node)) {
    189         out << ValueOf(node);
    190         if (print_weights) out << " " << WeightOf(node);
    191         out << " " << code << std::endl;
    192       } else {
    193         if (LeftOf(node)) queue.emplace(LeftOf(node), code + "0");
    194 
    195         if (RightOf(node)) queue.emplace(RightOf(node), code + "1");
    196       }
    197     }
    198   }
    199 
    200   // Returns the Huffman table. The table was built at at construction time,
    201   // this function just returns a const reference.
    202   const std::unordered_map<Val, std::pair<uint64_t, size_t>>& GetEncodingTable()
    203       const {
    204     return encoding_table_;
    205   }
    206 
    207   // Encodes |val| and stores its Huffman code in the lower |num_bits| of
    208   // |bits|. Returns false of |val| is not in the Huffman table.
    209   bool Encode(const Val& val, uint64_t* bits, size_t* num_bits) const {
    210     auto it = encoding_table_.find(val);
    211     if (it == encoding_table_.end()) return false;
    212     *bits = it->second.first;
    213     *num_bits = it->second.second;
    214     return true;
    215   }
    216 
    217   // Reads bits one-by-one using callback |read_bit| until a match is found.
    218   // Matching value is stored in |val|. Returns false if |read_bit| terminates
    219   // before a code was mathced.
    220   // |read_bit| has type bool func(bool* bit). When called, the next bit is
    221   // stored in |bit|. |read_bit| returns false if the stream terminates
    222   // prematurely.
    223   bool DecodeFromStream(const std::function<bool(bool*)>& read_bit,
    224                         Val* val) const {
    225     uint32_t node = root_;
    226     while (true) {
    227       assert(node);
    228 
    229       if (!RightOf(node) && !LeftOf(node)) {
    230         *val = ValueOf(node);
    231         return true;
    232       }
    233 
    234       bool go_right;
    235       if (!read_bit(&go_right)) return false;
    236 
    237       if (go_right)
    238         node = RightOf(node);
    239       else
    240         node = LeftOf(node);
    241     }
    242 
    243     assert(0);
    244     return false;
    245   }
    246 
    247  private:
    248   // Returns value of the node referenced by |handle|.
    249   Val ValueOf(uint32_t node) const { return nodes_.at(node).value; }
    250 
    251   // Returns left child of |node|.
    252   uint32_t LeftOf(uint32_t node) const { return nodes_.at(node).left; }
    253 
    254   // Returns right child of |node|.
    255   uint32_t RightOf(uint32_t node) const { return nodes_.at(node).right; }
    256 
    257   // Returns weight of |node|.
    258   uint32_t WeightOf(uint32_t node) const { return nodes_.at(node).weight; }
    259 
    260   // Returns id of |node|.
    261   uint32_t IdOf(uint32_t node) const { return nodes_.at(node).id; }
    262 
    263   // Returns mutable reference to value of |node|.
    264   Val& MutableValueOf(uint32_t node) {
    265     assert(node);
    266     return nodes_.at(node).value;
    267   }
    268 
    269   // Returns mutable reference to handle of left child of |node|.
    270   uint32_t& MutableLeftOf(uint32_t node) {
    271     assert(node);
    272     return nodes_.at(node).left;
    273   }
    274 
    275   // Returns mutable reference to handle of right child of |node|.
    276   uint32_t& MutableRightOf(uint32_t node) {
    277     assert(node);
    278     return nodes_.at(node).right;
    279   }
    280 
    281   // Returns mutable reference to weight of |node|.
    282   uint32_t& MutableWeightOf(uint32_t node) { return nodes_.at(node).weight; }
    283 
    284   // Returns mutable reference to id of |node|.
    285   uint32_t& MutableIdOf(uint32_t node) { return nodes_.at(node).id; }
    286 
    287   // Returns true if |left| has bigger weight than |right|. Node ids are
    288   // used as tie-breaker.
    289   bool LeftIsBigger(uint32_t left, uint32_t right) const {
    290     if (WeightOf(left) == WeightOf(right)) {
    291       assert(IdOf(left) != IdOf(right));
    292       return IdOf(left) > IdOf(right);
    293     }
    294     return WeightOf(left) > WeightOf(right);
    295   }
    296 
    297   // Prints subtree (helper function used by PrintTree).
    298   void PrintTreeInternal(std::ostream& out, uint32_t node, size_t depth) const {
    299     if (!node) return;
    300 
    301     const size_t kTextFieldWidth = 7;
    302 
    303     if (!RightOf(node) && !LeftOf(node)) {
    304       out << ValueOf(node) << std::endl;
    305     } else {
    306       if (RightOf(node)) {
    307         std::stringstream label;
    308         label << std::setfill('-') << std::left << std::setw(kTextFieldWidth)
    309               << WeightOf(RightOf(node));
    310         out << label.str();
    311         PrintTreeInternal(out, RightOf(node), depth + 1);
    312       }
    313 
    314       if (LeftOf(node)) {
    315         out << std::string(depth * kTextFieldWidth, ' ');
    316         std::stringstream label;
    317         label << std::setfill('-') << std::left << std::setw(kTextFieldWidth)
    318               << WeightOf(LeftOf(node));
    319         out << label.str();
    320         PrintTreeInternal(out, LeftOf(node), depth + 1);
    321       }
    322     }
    323   }
    324 
    325   // Traverses the Huffman tree and saves paths to the leaves as bit
    326   // sequences to encoding_table_.
    327   void CreateEncodingTable() {
    328     struct Context {
    329       Context(uint32_t in_node, uint64_t in_bits, size_t in_depth)
    330           : node(in_node), bits(in_bits), depth(in_depth) {}
    331       uint32_t node;
    332       // Huffman tree depth cannot exceed 64 as histogramm counts are expected
    333       // to be positive and limited by numeric_limits<uint32_t>::max().
    334       // For practical applications tree depth would be much smaller than 64.
    335       uint64_t bits;
    336       size_t depth;
    337     };
    338 
    339     std::queue<Context> queue;
    340     queue.emplace(root_, 0, 0);
    341 
    342     while (!queue.empty()) {
    343       const Context& context = queue.front();
    344       const uint32_t node = context.node;
    345       const uint64_t bits = context.bits;
    346       const size_t depth = context.depth;
    347       queue.pop();
    348 
    349       if (!RightOf(node) && !LeftOf(node)) {
    350         auto insertion_result = encoding_table_.emplace(
    351             ValueOf(node), std::pair<uint64_t, size_t>(bits, depth));
    352         assert(insertion_result.second);
    353         (void)insertion_result;
    354       } else {
    355         if (LeftOf(node)) queue.emplace(LeftOf(node), bits, depth + 1);
    356 
    357         if (RightOf(node))
    358           queue.emplace(RightOf(node), bits | (1ULL << depth), depth + 1);
    359       }
    360     }
    361   }
    362 
    363   // Creates new Huffman tree node and stores it in the deleter array.
    364   uint32_t CreateNode() {
    365     const uint32_t handle = static_cast<uint32_t>(nodes_.size());
    366     nodes_.emplace_back(Node());
    367     nodes_.back().id = next_node_id_++;
    368     return handle;
    369   }
    370 
    371   // Huffman tree root handle.
    372   uint32_t root_ = 0;
    373 
    374   // Huffman tree deleter.
    375   std::vector<Node> nodes_;
    376 
    377   // Encoding table value -> {bits, num_bits}.
    378   // Huffman codes are expected to never exceed 64 bit length (this is in fact
    379   // impossible if frequencies are stored as uint32_t).
    380   std::unordered_map<Val, std::pair<uint64_t, size_t>> encoding_table_;
    381 
    382   // Next node id issued by CreateNode();
    383   uint32_t next_node_id_ = 1;
    384 };
    385 
    386 }  // namespace comp
    387 }  // namespace spvtools
    388 
    389 #endif  // SOURCE_COMP_HUFFMAN_CODEC_H_
    390