diff options
author | Andrey Tuganov <andreyt@google.com> | 2017-08-11 16:51:24 -0400 |
---|---|---|
committer | David Neto <dneto@google.com> | 2017-08-15 23:57:21 -0400 |
commit | 17d941af4fa87e920ae2779cb2f3b8decd99a9a0 (patch) | |
tree | 3b84bcd8d7d59a8aba24639767eab92e9b18182d | |
parent | 1d477b9898006887129a972d59719d9294a80c31 (diff) | |
download | SPIRV-Tools-17d941af4fa87e920ae2779cb2f3b8decd99a9a0.tar.gz SPIRV-Tools-17d941af4fa87e920ae2779cb2f3b8decd99a9a0.tar.bz2 SPIRV-Tools-17d941af4fa87e920ae2779cb2f3b8decd99a9a0.zip |
Huffman codec can serialize to text
Refactored the Huffman codec implementation and added the ability to
serialize to a C++-like text format. This would reduce the time complexity
of loading hard-coded codecs.
-rw-r--r-- | source/util/huffman_codec.h | 256 | ||||
-rw-r--r-- | test/huffman_codec.cpp | 98 |
2 files changed, 281 insertions, 73 deletions
diff --git a/source/util/huffman_codec.h b/source/util/huffman_codec.h index 2e74d6b8..35880203 100644 --- a/source/util/huffman_codec.h +++ b/source/util/huffman_codec.h @@ -38,31 +38,53 @@ namespace spvutils { // literal). template <class Val> class HuffmanCodec { - struct Node; - public: + // Huffman tree node. + struct Node { + Node() {} + + // Creates Node from serialization leaving weight and id undefined. + Node(const Val& in_value, uint32_t in_left, uint32_t in_right) + : value(in_value), left(in_left), right(in_right) {} + + Val value = Val(); + uint32_t weight = 0; + // Ids are issued sequentially starting from 1. Ids are used as an ordering + // tie-breaker, to make sure that the ordering (and resulting coding scheme) + // are consistent accross multiple platforms. + uint32_t id = 0; + // Handles of children. + uint32_t left = 0; + uint32_t right = 0; + }; + // Creates Huffman codec from a histogramm. // Histogramm counts must not be zero. explicit HuffmanCodec(const std::map<Val, uint32_t>& hist) { if (hist.empty()) return; // Heuristic estimate. - all_nodes_.reserve(3 * hist.size()); + nodes_.reserve(3 * hist.size()); + + // Create NIL. + CreateNode(); // The queue is sorted in ascending order by weight (or by node id if // weights are equal). - std::vector<Node*> queue_vector; + std::vector<uint32_t> queue_vector; queue_vector.reserve(hist.size()); - std::priority_queue<Node*, std::vector<Node*>, - std::function<bool(const Node*, const Node*)>> - queue(LeftIsBigger, std::move(queue_vector)); + std::priority_queue<uint32_t, std::vector<uint32_t>, + std::function<bool(uint32_t, uint32_t)>> + queue(std::bind(&HuffmanCodec::LeftIsBigger, this, + std::placeholders::_1, std::placeholders::_2), + std::move(queue_vector)); // Put all leaves in the queue. 
for (const auto& pair : hist) { - Node* node = CreateNode(); - node->val = pair.first; - node->weight = pair.second; - assert(node->weight); + const uint32_t node = CreateNode(); + MutableValueOf(node) = pair.first; + MutableWeightOf(node) = pair.second; + assert(WeightOf(node)); queue.push(node); } @@ -73,7 +95,7 @@ class HuffmanCodec { // supposed to be empty at this point, unless there are no leaves, but // that case was already handled. assert(!queue.empty()); - Node* right = queue.top(); + const uint32_t right = queue.top(); queue.pop(); // If the queue is empty at this point, then the last node is @@ -83,14 +105,14 @@ class HuffmanCodec { break; } - Node* left = queue.top(); + const uint32_t left = queue.top(); queue.pop(); // Combine left and right into a new tree and push it into the queue. - Node* parent = CreateNode(); - parent->weight = right->weight + left->weight; - parent->left = left; - parent->right = right; + const uint32_t parent = CreateNode(); + MutableWeightOf(parent) = WeightOf(right) + WeightOf(left); + MutableLeftOf(parent) = left; + MutableRightOf(parent) = right; queue.push(parent); } @@ -98,37 +120,83 @@ class HuffmanCodec { CreateEncodingTable(); } + // Creates Huffman codec from saved tree structure. + // |nodes| is the list of nodes of the tree, nodes[0] being NIL. + // |root_handle| is the index of the root node. + HuffmanCodec(uint32_t root_handle, std::vector<Node>&& nodes) { + nodes_ = std::move(nodes); + assert(!nodes_.empty()); + assert(root_handle > 0 && root_handle < nodes_.size()); + assert(!LeftOf(0) && !RightOf(0)); + + root_ = root_handle; + + // Traverse the tree and form encoding table. + CreateEncodingTable(); + } + + // Serializes the codec in the following text format: + // (<root_handle>, { + // {0, 0, 0}, + // {val1, left1, right1}, + // {val2, left2, right2}, + // ... 
+ // }) + std::string SerializeToText(int indent_num_whitespaces) const { + const bool value_is_text = std::is_same<Val, std::string>::value; + + const std::string indent1 = std::string(indent_num_whitespaces, ' '); + const std::string indent2 = std::string(indent_num_whitespaces + 2, ' '); + + std::stringstream code; + code << "(" << root_ << ", {\n"; + + for (const Node& node : nodes_) { + code << indent2 << "{"; + if (value_is_text) + code << "\""; + code << node.value; + if (value_is_text) + code << "\""; + code << ", " << node.left << ", " << node.right << "},\n"; + } + + code << indent1 << "})"; + + return code.str(); + } + // Prints the Huffman tree in the following format: // w------w------'x' // w------'y' // Where w stands for the weight of the node. // Right tree branches appear above left branches. Taking the right path // adds 1 to the code, taking the left adds 0. - void PrintTree(std::ostream& out) { + void PrintTree(std::ostream& out) const { PrintTreeInternal(out, root_, 0); } // Traverses the tree and prints the Huffman table: value, code // and optionally node weight for every leaf. 
void PrintTable(std::ostream& out, bool print_weights = true) { - std::queue<std::pair<Node*, std::string>> queue; + std::queue<std::pair<uint32_t, std::string>> queue; queue.emplace(root_, ""); while (!queue.empty()) { - const Node* node = queue.front().first; + const uint32_t node = queue.front().first; const std::string code = queue.front().second; queue.pop(); - if (!node->right && !node->left) { - out << node->val; + if (!RightOf(node) && !LeftOf(node)) { + out << ValueOf(node); if (print_weights) - out << " " << node->weight; + out << " " << WeightOf(node); out << " " << code << std::endl; } else { - if (node->left) - queue.emplace(node->left, code + "0"); + if (LeftOf(node)) + queue.emplace(LeftOf(node), code + "0"); - if (node->right) - queue.emplace(node->right, code + "1"); + if (RightOf(node)) + queue.emplace(RightOf(node), code + "1"); } } } @@ -158,12 +226,12 @@ class HuffmanCodec { // stored in |bit|. |read_bit| returns false if the stream terminates // prematurely. bool DecodeFromStream(const std::function<bool(bool*)>& read_bit, Val* val) { - Node* node = root_; + uint32_t node = root_; while (true) { assert(node); - if (node->left == nullptr && node->right == nullptr) { - *val = node->val; + if (!RightOf(node) && !LeftOf(node)) { + *val = ValueOf(node); return true; } @@ -172,9 +240,9 @@ class HuffmanCodec { return false; if (go_right) - node = node->right; + node = RightOf(node); else - node = node->left; + node = LeftOf(node); } assert (0); @@ -182,53 +250,94 @@ class HuffmanCodec { } private: - // Huffman tree node. - struct Node { - Val val = Val(); - uint32_t weight = 0; - // Ids are issued sequentially starting from 1. Ids are used as an ordering - // tie-breaker, to make sure that the ordering (and resulting coding scheme) - // are consistent accross multiple platforms. - uint32_t id = 0; - Node* left = nullptr; - Node* right = nullptr; - }; + // Returns value of the node referenced by |handle|. 
+ Val ValueOf(uint32_t node) const { + return nodes_.at(node).value; + } + + // Returns left child of |node|. + uint32_t LeftOf(uint32_t node) const { + return nodes_.at(node).left; + } + + // Returns right child of |node|. + uint32_t RightOf(uint32_t node) const { + return nodes_.at(node).right; + } + + // Returns weight of |node|. + uint32_t WeightOf(uint32_t node) const { + return nodes_.at(node).weight; + } + + // Returns id of |node|. + uint32_t IdOf(uint32_t node) const { + return nodes_.at(node).id; + } + + // Returns mutable reference to value of |node|. + Val& MutableValueOf(uint32_t node) { + assert(node); + return nodes_.at(node).value; + } + + // Returns mutable reference to handle of left child of |node|. + uint32_t& MutableLeftOf(uint32_t node) { + assert(node); + return nodes_.at(node).left; + } + + // Returns mutable reference to handle of right child of |node|. + uint32_t& MutableRightOf(uint32_t node) { + assert(node); + return nodes_.at(node).right; + } + + // Returns mutable reference to weight of |node|. + uint32_t& MutableWeightOf(uint32_t node) { + return nodes_.at(node).weight; + } + + // Returns mutable reference to id of |node|. + uint32_t& MutableIdOf(uint32_t node) { + return nodes_.at(node).id; + } // Returns true if |left| has bigger weight than |right|. Node ids are // used as tie-breaker. - static bool LeftIsBigger(const Node* left, const Node* right) { - if (left->weight == right->weight) { - assert (left->id != right->id); - return left->id > right->id; + bool LeftIsBigger(uint32_t left, uint32_t right) const { + if (WeightOf(left) == WeightOf(right)) { + assert (IdOf(left) != IdOf(right)); + return IdOf(left) > IdOf(right); } - return left->weight > right->weight; + return WeightOf(left) > WeightOf(right); } // Prints subtree (helper function used by PrintTree). 
- static void PrintTreeInternal(std::ostream& out, Node* node, size_t depth) { + void PrintTreeInternal(std::ostream& out, uint32_t node, size_t depth) const { if (!node) return; const size_t kTextFieldWidth = 7; - if (!node->right && !node->left) { - out << node->val << std::endl; + if (!RightOf(node) && !LeftOf(node)) { + out << ValueOf(node) << std::endl; } else { - if (node->right) { + if (RightOf(node)) { std::stringstream label; label << std::setfill('-') << std::left << std::setw(kTextFieldWidth) - << node->right->weight; + << WeightOf(RightOf(node)); out << label.str(); - PrintTreeInternal(out, node->right, depth + 1); + PrintTreeInternal(out, RightOf(node), depth + 1); } - if (node->left) { + if (LeftOf(node)) { out << std::string(depth * kTextFieldWidth, ' '); std::stringstream label; label << std::setfill('-') << std::left << std::setw(kTextFieldWidth) - << node->left->weight; + << WeightOf(LeftOf(node)); out << label.str(); - PrintTreeInternal(out, node->left, depth + 1); + PrintTreeInternal(out, LeftOf(node), depth + 1); } } } @@ -237,9 +346,9 @@ class HuffmanCodec { // sequences to encoding_table_. void CreateEncodingTable() { struct Context { - Context(Node* in_node, uint64_t in_bits, size_t in_depth) + Context(uint32_t in_node, uint64_t in_bits, size_t in_depth) : node(in_node), bits(in_bits), depth(in_depth) {} - Node* node; + uint32_t node; // Huffman tree depth cannot exceed 64 as histogramm counts are expected // to be positive and limited by numeric_limits<uint32_t>::max(). // For practical applications tree depth would be much smaller than 64. 
@@ -252,38 +361,39 @@ class HuffmanCodec { while (!queue.empty()) { const Context& context = queue.front(); - const Node* node = context.node; + const uint32_t node = context.node; const uint64_t bits = context.bits; const size_t depth = context.depth; queue.pop(); - if (!node->right && !node->left) { + if (!RightOf(node) && !LeftOf(node)) { auto insertion_result = encoding_table_.emplace( - node->val, std::pair<uint64_t, size_t>(bits, depth)); + ValueOf(node), std::pair<uint64_t, size_t>(bits, depth)); assert(insertion_result.second); (void)insertion_result; } else { - if (node->left) - queue.emplace(node->left, bits, depth + 1); + if (LeftOf(node)) + queue.emplace(LeftOf(node), bits, depth + 1); - if (node->right) - queue.emplace(node->right, bits | (1ULL << depth), depth + 1); + if (RightOf(node)) + queue.emplace(RightOf(node), bits | (1ULL << depth), depth + 1); } } } // Creates new Huffman tree node and stores it in the deleter array. - Node* CreateNode() { - all_nodes_.emplace_back(new Node()); - all_nodes_.back()->id = next_node_id_++; - return all_nodes_.back().get(); + uint32_t CreateNode() { + const uint32_t handle = static_cast<uint32_t>(nodes_.size()); + nodes_.emplace_back(Node()); + nodes_.back().id = next_node_id_++; + return handle; } - // Huffman tree root. - Node* root_ = nullptr; + // Huffman tree root handle. + uint32_t root_ = 0; // Huffman tree deleter. - std::vector<std::unique_ptr<Node>> all_nodes_; + std::vector<Node> nodes_; // Encoding table value -> {bits, num_bits}. 
// Huffman codes are expected to never exceed 64 bit length (this is in fact diff --git a/test/huffman_codec.cpp b/test/huffman_codec.cpp index 80f7d8f8..abad1c09 100644 --- a/test/huffman_codec.cpp +++ b/test/huffman_codec.cpp @@ -217,4 +217,102 @@ TEST(Huffman, TestDecodeNumbers) { EXPECT_EQ(3u, decoded); } +TEST(Huffman, SerializeToTextU64) { + const std::map<uint64_t, uint32_t> hist = + { {1001, 10}, {1002, 5}, {1003, 15} }; + HuffmanCodec<uint64_t> huffman(hist); + + const std::string code = huffman.SerializeToText(2); + + const std::string expected = R"((5, { + {0, 0, 0}, + {1001, 0, 0}, + {1002, 0, 0}, + {1003, 0, 0}, + {0, 1, 2}, + {0, 4, 3}, + }))"; + + + ASSERT_EQ(expected, code); +} + +TEST(Huffman, SerializeToTextString) { + const std::map<std::string, uint32_t> hist = + { {"aaa", 10}, {"bbb", 20}, {"ccc", 15} }; + HuffmanCodec<std::string> huffman(hist); + + const std::string code = huffman.SerializeToText(4); + + const std::string expected = R"((5, { + {"", 0, 0}, + {"aaa", 0, 0}, + {"bbb", 0, 0}, + {"ccc", 0, 0}, + {"", 3, 1}, + {"", 4, 2}, + }))"; + + ASSERT_EQ(expected, code); +} + +TEST(Huffman, CreateFromTextString) { + std::vector<HuffmanCodec<std::string>::Node> nodes = { + {}, + {"root", 2, 3}, + {"left", 0, 0}, + {"right", 0, 0}, + }; + + HuffmanCodec<std::string> huffman(1, std::move(nodes)); + + std::stringstream ss; + huffman.PrintTree(ss); + + const std::string expected = std::string(R"( +0------right +0------left +)").substr(1); + + EXPECT_EQ(expected, ss.str()); +} + +TEST(Huffman, CreateFromTextU64) { + HuffmanCodec<uint64_t> huffman(5, { + {0, 0, 0}, + {1001, 0, 0}, + {1002, 0, 0}, + {1003, 0, 0}, + {0, 1, 2}, + {0, 4, 3}, + }); + + std::stringstream ss; + huffman.PrintTree(ss); + + const std::string expected = std::string(R"( +0------1003 +0------0------1002 + 0------1001 +)").substr(1); + + EXPECT_EQ(expected, ss.str()); + + TestBitReader bit_reader("01"); + auto read_bit = [&bit_reader](bool* bit) { + return 
bit_reader.ReadBit(bit); + }; + + uint64_t decoded = 0; + ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded)); + EXPECT_EQ(1002u, decoded); + + uint64_t bits = 0; + size_t num_bits = 0; + + EXPECT_TRUE(huffman.Encode(1001, &bits, &num_bits)); + EXPECT_EQ(2u, num_bits); + EXPECT_EQ("00", BitsToStream(bits, num_bits)); +} + } // anonymous namespace |