Home | History | Annotate | Download | only in util
      1 // Copyright (c) 2017 Google Inc.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
// Contains utils for building, serializing and applying Huffman codes.
     16 
     17 #ifndef LIBSPIRV_UTIL_HUFFMAN_CODEC_H_
     18 #define LIBSPIRV_UTIL_HUFFMAN_CODEC_H_
     19 
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iomanip>
#include <map>
#include <memory>
#include <ostream>
#include <queue>
#include <sstream>
#include <stack>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
     33 
     34 namespace spvutils {
     35 
     36 // Used to generate and apply a Huffman coding scheme.
     37 // |Val| is the type of variable being encoded (for example a string or a
     38 // literal).
     39 template <class Val>
     40 class HuffmanCodec {
     41  public:
     42   // Huffman tree node.
     43   struct Node {
     44     Node() {}
     45 
     46     // Creates Node from serialization leaving weight and id undefined.
     47     Node(const Val& in_value, uint32_t in_left, uint32_t in_right)
     48         : value(in_value), left(in_left), right(in_right) {}
     49 
     50     Val value = Val();
     51     uint32_t weight = 0;
     52     // Ids are issued sequentially starting from 1. Ids are used as an ordering
     53     // tie-breaker, to make sure that the ordering (and resulting coding scheme)
     54     // are consistent accross multiple platforms.
     55     uint32_t id = 0;
     56     // Handles of children.
     57     uint32_t left = 0;
     58     uint32_t right = 0;
     59   };
     60 
     61   // Creates Huffman codec from a histogramm.
     62   // Histogramm counts must not be zero.
     63   explicit HuffmanCodec(const std::map<Val, uint32_t>& hist) {
     64     if (hist.empty()) return;
     65 
     66     // Heuristic estimate.
     67     nodes_.reserve(3 * hist.size());
     68 
     69     // Create NIL.
     70     CreateNode();
     71 
     72     // The queue is sorted in ascending order by weight (or by node id if
     73     // weights are equal).
     74     std::vector<uint32_t> queue_vector;
     75     queue_vector.reserve(hist.size());
     76     std::priority_queue<uint32_t, std::vector<uint32_t>,
     77         std::function<bool(uint32_t, uint32_t)>>
     78             queue(std::bind(&HuffmanCodec::LeftIsBigger, this,
     79                             std::placeholders::_1, std::placeholders::_2),
     80                   std::move(queue_vector));
     81 
     82     // Put all leaves in the queue.
     83     for (const auto& pair : hist) {
     84       const uint32_t node = CreateNode();
     85       MutableValueOf(node) = pair.first;
     86       MutableWeightOf(node) = pair.second;
     87       assert(WeightOf(node));
     88       queue.push(node);
     89     }
     90 
     91     // Form the tree by combining two subtrees with the least weight,
     92     // and pushing the root of the new tree in the queue.
     93     while (true) {
     94       // We push a node at the end of each iteration, so the queue is never
     95       // supposed to be empty at this point, unless there are no leaves, but
     96       // that case was already handled.
     97       assert(!queue.empty());
     98       const uint32_t right = queue.top();
     99       queue.pop();
    100 
    101       // If the queue is empty at this point, then the last node is
    102       // the root of the complete Huffman tree.
    103       if (queue.empty()) {
    104         root_ = right;
    105         break;
    106       }
    107 
    108       const uint32_t left = queue.top();
    109       queue.pop();
    110 
    111       // Combine left and right into a new tree and push it into the queue.
    112       const uint32_t parent = CreateNode();
    113       MutableWeightOf(parent) = WeightOf(right) + WeightOf(left);
    114       MutableLeftOf(parent) = left;
    115       MutableRightOf(parent) = right;
    116       queue.push(parent);
    117     }
    118 
    119     // Traverse the tree and form encoding table.
    120     CreateEncodingTable();
    121   }
    122 
    123   // Creates Huffman codec from saved tree structure.
    124   // |nodes| is the list of nodes of the tree, nodes[0] being NIL.
    125   // |root_handle| is the index of the root node.
    126   HuffmanCodec(uint32_t root_handle, std::vector<Node>&& nodes) {
    127     nodes_ = std::move(nodes);
    128     assert(!nodes_.empty());
    129     assert(root_handle > 0 && root_handle < nodes_.size());
    130     assert(!LeftOf(0) && !RightOf(0));
    131 
    132     root_ = root_handle;
    133 
    134     // Traverse the tree and form encoding table.
    135     CreateEncodingTable();
    136   }
    137 
    138   // Serializes the codec in the following text format:
    139   // (<root_handle>, {
    140   //   {0, 0, 0},
    141   //   {val1, left1, right1},
    142   //   {val2, left2, right2},
    143   //   ...
    144   // })
    145   std::string SerializeToText(int indent_num_whitespaces) const {
    146     const bool value_is_text = std::is_same<Val, std::string>::value;
    147 
    148     const std::string indent1 = std::string(indent_num_whitespaces, ' ');
    149     const std::string indent2 = std::string(indent_num_whitespaces + 2, ' ');
    150 
    151     std::stringstream code;
    152     code << "(" << root_ << ", {\n";
    153 
    154     for (const Node& node : nodes_) {
    155       code << indent2 << "{";
    156       if (value_is_text)
    157         code << "\"";
    158       code << node.value;
    159       if (value_is_text)
    160         code << "\"";
    161       code << ", " << node.left << ", " << node.right << "},\n";
    162     }
    163 
    164     code << indent1 << "})";
    165 
    166     return code.str();
    167   }
    168 
    169   // Prints the Huffman tree in the following format:
    170   // w------w------'x'
    171   //        w------'y'
    172   // Where w stands for the weight of the node.
    173   // Right tree branches appear above left branches. Taking the right path
    174   // adds 1 to the code, taking the left adds 0.
    175   void PrintTree(std::ostream& out) const {
    176     PrintTreeInternal(out, root_, 0);
    177   }
    178 
    179   // Traverses the tree and prints the Huffman table: value, code
    180   // and optionally node weight for every leaf.
    181   void PrintTable(std::ostream& out, bool print_weights = true) {
    182     std::queue<std::pair<uint32_t, std::string>> queue;
    183     queue.emplace(root_, "");
    184 
    185     while (!queue.empty()) {
    186       const uint32_t node = queue.front().first;
    187       const std::string code = queue.front().second;
    188       queue.pop();
    189       if (!RightOf(node) && !LeftOf(node)) {
    190         out << ValueOf(node);
    191         if (print_weights)
    192             out << " " << WeightOf(node);
    193         out << " " << code << std::endl;
    194       } else {
    195         if (LeftOf(node))
    196           queue.emplace(LeftOf(node), code + "0");
    197 
    198         if (RightOf(node))
    199           queue.emplace(RightOf(node), code + "1");
    200       }
    201     }
    202   }
    203 
    204   // Returns the Huffman table. The table was built at at construction time,
    205   // this function just returns a const reference.
    206   const std::unordered_map<Val, std::pair<uint64_t, size_t>>&
    207       GetEncodingTable() const {
    208     return encoding_table_;
    209   }
    210 
    211   // Encodes |val| and stores its Huffman code in the lower |num_bits| of
    212   // |bits|. Returns false of |val| is not in the Huffman table.
    213   bool Encode(const Val& val, uint64_t* bits, size_t* num_bits) {
    214     auto it = encoding_table_.find(val);
    215     if (it == encoding_table_.end())
    216       return false;
    217     *bits = it->second.first;
    218     *num_bits = it->second.second;
    219     return true;
    220   }
    221 
    222   // Reads bits one-by-one using callback |read_bit| until a match is found.
    223   // Matching value is stored in |val|. Returns false if |read_bit| terminates
    224   // before a code was mathced.
    225   // |read_bit| has type bool func(bool* bit). When called, the next bit is
    226   // stored in |bit|. |read_bit| returns false if the stream terminates
    227   // prematurely.
    228   bool DecodeFromStream(const std::function<bool(bool*)>& read_bit, Val* val) {
    229     uint32_t node = root_;
    230     while (true) {
    231       assert(node);
    232 
    233       if (!RightOf(node) && !LeftOf(node)) {
    234         *val = ValueOf(node);
    235         return true;
    236       }
    237 
    238       bool go_right;
    239       if (!read_bit(&go_right))
    240         return false;
    241 
    242       if (go_right)
    243         node = RightOf(node);
    244       else
    245         node = LeftOf(node);
    246     }
    247 
    248     assert (0);
    249     return false;
    250   }
    251 
    252  private:
    253   // Returns value of the node referenced by |handle|.
    254   Val ValueOf(uint32_t node) const {
    255     return nodes_.at(node).value;
    256   }
    257 
    258   // Returns left child of |node|.
    259   uint32_t LeftOf(uint32_t node) const {
    260     return nodes_.at(node).left;
    261   }
    262 
    263   // Returns right child of |node|.
    264   uint32_t RightOf(uint32_t node) const {
    265     return nodes_.at(node).right;
    266   }
    267 
    268   // Returns weight of |node|.
    269   uint32_t WeightOf(uint32_t node) const {
    270     return nodes_.at(node).weight;
    271   }
    272 
    273   // Returns id of |node|.
    274   uint32_t IdOf(uint32_t node) const {
    275     return nodes_.at(node).id;
    276   }
    277 
    278   // Returns mutable reference to value of |node|.
    279   Val& MutableValueOf(uint32_t node) {
    280     assert(node);
    281     return nodes_.at(node).value;
    282   }
    283 
    284   // Returns mutable reference to handle of left child of |node|.
    285   uint32_t& MutableLeftOf(uint32_t node) {
    286     assert(node);
    287     return nodes_.at(node).left;
    288   }
    289 
    290   // Returns mutable reference to handle of right child of |node|.
    291   uint32_t& MutableRightOf(uint32_t node) {
    292     assert(node);
    293     return nodes_.at(node).right;
    294   }
    295 
    296   // Returns mutable reference to weight of |node|.
    297   uint32_t& MutableWeightOf(uint32_t node) {
    298     return nodes_.at(node).weight;
    299   }
    300 
    301   // Returns mutable reference to id of |node|.
    302   uint32_t& MutableIdOf(uint32_t node) {
    303     return nodes_.at(node).id;
    304   }
    305 
    306   // Returns true if |left| has bigger weight than |right|. Node ids are
    307   // used as tie-breaker.
    308   bool LeftIsBigger(uint32_t left, uint32_t right) const {
    309     if (WeightOf(left) == WeightOf(right)) {
    310       assert (IdOf(left) != IdOf(right));
    311       return IdOf(left) > IdOf(right);
    312     }
    313     return WeightOf(left) > WeightOf(right);
    314   }
    315 
    316   // Prints subtree (helper function used by PrintTree).
    317   void PrintTreeInternal(std::ostream& out, uint32_t node, size_t depth) const {
    318     if (!node)
    319       return;
    320 
    321     const size_t kTextFieldWidth = 7;
    322 
    323     if (!RightOf(node) && !LeftOf(node)) {
    324       out << ValueOf(node) << std::endl;
    325     } else {
    326       if (RightOf(node)) {
    327         std::stringstream label;
    328         label << std::setfill('-') << std::left << std::setw(kTextFieldWidth)
    329               << WeightOf(RightOf(node));
    330         out << label.str();
    331         PrintTreeInternal(out, RightOf(node), depth + 1);
    332       }
    333 
    334       if (LeftOf(node)) {
    335         out << std::string(depth * kTextFieldWidth, ' ');
    336         std::stringstream label;
    337         label << std::setfill('-') << std::left << std::setw(kTextFieldWidth)
    338               << WeightOf(LeftOf(node));
    339         out << label.str();
    340         PrintTreeInternal(out, LeftOf(node), depth + 1);
    341       }
    342     }
    343   }
    344 
    345   // Traverses the Huffman tree and saves paths to the leaves as bit
    346   // sequences to encoding_table_.
    347   void CreateEncodingTable() {
    348     struct Context {
    349       Context(uint32_t in_node, uint64_t in_bits, size_t in_depth)
    350           :  node(in_node), bits(in_bits), depth(in_depth) {}
    351       uint32_t node;
    352       // Huffman tree depth cannot exceed 64 as histogramm counts are expected
    353       // to be positive and limited by numeric_limits<uint32_t>::max().
    354       // For practical applications tree depth would be much smaller than 64.
    355       uint64_t bits;
    356       size_t depth;
    357     };
    358 
    359     std::queue<Context> queue;
    360     queue.emplace(root_, 0, 0);
    361 
    362     while (!queue.empty()) {
    363       const Context& context = queue.front();
    364       const uint32_t node = context.node;
    365       const uint64_t bits = context.bits;
    366       const size_t depth = context.depth;
    367       queue.pop();
    368 
    369       if (!RightOf(node) && !LeftOf(node)) {
    370         auto insertion_result = encoding_table_.emplace(
    371             ValueOf(node), std::pair<uint64_t, size_t>(bits, depth));
    372         assert(insertion_result.second);
    373         (void)insertion_result;
    374       } else {
    375         if (LeftOf(node))
    376           queue.emplace(LeftOf(node), bits, depth + 1);
    377 
    378         if (RightOf(node))
    379           queue.emplace(RightOf(node), bits | (1ULL << depth), depth + 1);
    380       }
    381     }
    382   }
    383 
    384   // Creates new Huffman tree node and stores it in the deleter array.
    385   uint32_t CreateNode() {
    386     const uint32_t handle = static_cast<uint32_t>(nodes_.size());
    387     nodes_.emplace_back(Node());
    388     nodes_.back().id = next_node_id_++;
    389     return handle;
    390   }
    391 
    392   // Huffman tree root handle.
    393   uint32_t root_ = 0;
    394 
    395   // Huffman tree deleter.
    396   std::vector<Node> nodes_;
    397 
    398   // Encoding table value -> {bits, num_bits}.
    399   // Huffman codes are expected to never exceed 64 bit length (this is in fact
    400   // impossible if frequencies are stored as uint32_t).
    401   std::unordered_map<Val, std::pair<uint64_t, size_t>> encoding_table_;
    402 
    403   // Next node id issued by CreateNode();
    404   uint32_t next_node_id_ = 1;
    405 };
    406 
    407 }  // namespace spvutils
    408 
    409 #endif  // LIBSPIRV_UTIL_HUFFMAN_CODEC_H_
    410