From cf75229fb0839181babdce2aceaaba9ba4d25ce4 Mon Sep 17 00:00:00 2001 From: Rekai Musuka Date: Mon, 17 Apr 2023 22:06:18 -0500 Subject: [PATCH] feat: genericize HashArrayMappedTrie - actually run tests when using `zig build test` --- build.zig | 6 +- src/HashArrayMappedTrie.zig | 238 -------------------------------- src/main.zig | 62 ++++++--- src/tests.zig | 3 +- src/trie.zig | 266 ++++++++++++++++++++++++++++++++++++ 5 files changed, 313 insertions(+), 262 deletions(-) delete mode 100644 src/HashArrayMappedTrie.zig create mode 100644 src/trie.zig diff --git a/build.zig b/build.zig index 0295765..ce9f85c 100644 --- a/build.zig +++ b/build.zig @@ -53,15 +53,17 @@ pub fn build(b: *std.Build) void { run_step.dependOn(&run_cmd.step); // Creates a step for unit testing. - const exe_tests = b.addTest(.{ + const unit_tests = b.addTest(.{ .root_source_file = .{ .path = "src/tests.zig" }, .target = target, .optimize = optimize, }); + const run_unit_tests = b.addRunArtifact(unit_tests); + // Similar to creating the run step earlier, this exposes a `test` step to // the `zig build --help` menu, providing a way for the user to request // running the unit tests. const test_step = b.step("test", "Run unit tests"); - test_step.dependOn(&exe_tests.step); + test_step.dependOn(&run_unit_tests.step); } diff --git a/src/HashArrayMappedTrie.zig b/src/HashArrayMappedTrie.zig deleted file mode 100644 index 7e1a99b..0000000 --- a/src/HashArrayMappedTrie.zig +++ /dev/null @@ -1,238 +0,0 @@ -//! Hash Array Mapped Trie -//! https://idea.popcount.org/2012-07-25-introduction-to-hamt/ -const std = @import("std"); -// const Token = @import("Token.zig"); - -const Allocator = std.mem.Allocator; -const HashArrayMappedTrie = @This(); - -const t = 5; -const table_size = std.math.powi(u32, 2, t) catch unreachable; - -root: [table_size]?*Node, -allocator: Allocator, - -const Node = union(enum) { kv: Pair, table: Table }; -const Pair = struct { key: []const u8, value: void }; -const Table = struct { map: u32 = 0, base: [*]Node }; - -pub fn init(allocator: Allocator) !HashArrayMappedTrie { - return .{ - .root = [_]?*Node{null} ** table_size, - .allocator = allocator, - }; -} - -pub fn deinit(self: *HashArrayMappedTrie) void { - for (self.root) |maybe_node| { - const node = maybe_node orelse continue; - - deinitInner(self.allocator, node); - self.allocator.destroy(node); - } -} - -fn deinitInner(allocator: Allocator, node: *Node) void { - switch (node.*) { - .kv => |_| return, // will be deallocated by caller - .table => |table| { - const amt_ptr = table.base[0..@popCount(table.map)]; // Array Mapped Table - - for (amt_ptr) |*sub_node| { - if (sub_node.* == .table) { - deinitInner(allocator, sub_node); - } - } - - allocator.free(amt_ptr); - }, - } -} - -fn amtIdx(comptime T: type, bitset: T, offset: u16) std.math.Log2Int(T) { - const L2I = std.math.Log2Int(T); - - const shift_amt = @intCast(L2I, @typeInfo(T).Int.bits - offset); - return @truncate(L2I, bitset >> shift_amt); -} - -pub fn search(self: *HashArrayMappedTrie, key: []const u8) ?Pair { - const bitset = hash(key); - - // most siginificant t bits from hash - var hash_offset: u5 = t; - var current: *Node = self.root[amtIdx(u32, bitset, hash_offset)] orelse return null; - - while (true) { - switch (current.*) { - .table => |table| { - const mask = @as(u32, 1) << amtIdx(u32, bitset, hash_offset); - - if (table.map & mask != 0) { - const idx = @popCount(table.map & (mask - 1)); - current = &table.base[idx]; - - hash_offset += t; - } else return null; // hash table entry is empty - }, - .kv => |pair| { - if (!std.mem.eql(u8, pair.key, key)) return null; - return pair; - }, - } - } -} - -pub fn insert(self: *HashArrayMappedTrie, comptime key: []const u8, value: void) !void { - const bitset = hash(key); - - // most siginificant t bits from hash - var hash_offset: u5 = t; - const root_idx = amtIdx(u32, bitset, hash_offset); - - var current: *Node = self.root[root_idx] orelse { - // node in root table is empty, place the KV here - const node = try self.allocator.create(Node); - node.* = .{ .kv = .{ .key = key, .value = value } }; - - self.root[root_idx] = node; - return; - }; - - while (true) { - const mask = @as(u32, 1) << amtIdx(u32, bitset, hash_offset); - - switch (current.*) { - .table => |*table| { - if (table.map & mask == 0) { - // Empty - const old_len = @popCount(table.map); - const new_base = try self.allocator.alloc(Node, old_len + 1); - const new_map = table.map | mask; - - var i: u5 = 0; - for (0..@typeInfo(u32).Int.bits) |shift| { - const mask_loop = @as(u32, 1) << @intCast(u5, shift); - - if (new_map & mask_loop != 0) { - defer i += 1; - - const idx = @popCount(table.map & (mask_loop - 1)); - const copy = if (mask == mask_loop) Node{ .kv = Pair{ .key = key, .value = value } } else table.base[idx]; - new_base[i] = copy; - } - } - - self.allocator.free(table.base[0..old_len]); - table.base = new_base.ptr; - table.map = new_map; - - return; // inserted an elemnt into the Trie - } else { - // Found an entry in the array, continue loop (?) - const idx = @popCount(table.map & (mask - 1)); - current = &table.base[idx]; - - hash_offset += t; // Go one layer deper - } - }, - .kv => |prev_pair| { - const prev_bitset = hash(prev_pair.key); - const prev_mask = @as(u32, 1) << amtIdx(u32, prev_bitset, hash_offset); - - switch (std.math.order(mask, prev_mask)) { - .lt, .gt => { - // there are no collisions between the two hash subsets. - const pairs = try self.allocator.alloc(Node, 2); - const map = mask | prev_mask; - - pairs[@popCount(map & (prev_mask - 1))] = .{ .kv = prev_pair }; - pairs[@popCount(map & (mask - 1))] = .{ .kv = .{ .key = key, .value = value } }; - - current.* = .{ .table = .{ .map = map, .base = pairs.ptr } }; - return; - }, - .eq => { - const copied_pair = try self.allocator.alloc(Node, 1); - copied_pair[0] = .{ .kv = prev_pair }; - - current.* = .{ .table = .{ .map = mask, .base = copied_pair.ptr } }; - }, - } - }, - } - } -} - -pub fn walk(hamt: *const HashArrayMappedTrie) void { - for (hamt.root, 0..) |maybe_node, i| { - std.debug.print("{:0>2}: ", .{i}); - - if (maybe_node == null) { - std.debug.print("null\n", .{}); - } else { - recurseNodes(maybe_node.?, 1); - } - } -} - -fn recurseNodes(node: *Node, depth: u16) void { - switch (node.*) { - .kv => |pair| { - std.debug.print(".{{ .key = \"{s}\", .value = {} }}\n", .{ pair.key, pair.value }); - }, - .table => |table| { - std.debug.print(".{{ .map = 0x{X:0>8}, .ptr = {*} }}\n", .{ table.map, table.base }); - - for (0..@popCount(table.map)) |i| { - for (0..depth) |_| std.debug.print(" ", .{}); - std.debug.print("{:0>2}: ", .{i}); - - recurseNodes(&table.base[i], depth + 1); - } - }, - } -} - -fn hash(key: []const u8) u32 { - var result: u32 = 0; - - // 6 because we're working with 'a' -> 'z' - for (key) |c| result |= @as(u32, 1) << 6 + @intCast(u5, c - 'a'); - - return result; -} - -test "insert doesn't panic" { - var trie = try HashArrayMappedTrie.init(std.testing.allocator); - defer trie.deinit(); - - try trie.insert("hello", {}); -} - -test "search doesn't panic" { - var trie = try HashArrayMappedTrie.init(std.testing.allocator); - defer trie.deinit(); - - std.debug.assert(trie.search("hello") == null); -} - -test "insert edge cases" { - var trie = try HashArrayMappedTrie.init(std.heap.page_allocator); - defer trie.deinit(); - - // Basic Usage - try trie.insert("hello", {}); - try trie.insert("world", {}); - try trie.insert("zywxv", {}); - - // Colliding Keys - try trie.insert("abcde", {}); - try trie.insert("abcdef", {}); - - trie.walk(); - - try std.testing.expectEqual(Pair{ .key = "hello", .value = {} }, trie.search("hello").?); - - // try std.testing.expectEqual(Pair{ .key = "abcde", .value = {} }, trie.search("abcde").?); -} diff --git a/src/main.zig b/src/main.zig index 4feec08..e2157a4 100644 --- a/src/main.zig +++ b/src/main.zig @@ -1,5 +1,5 @@ const std = @import("std"); -const HashArrayMappedTrie = @import("HashArrayMappedTrie.zig"); +const HashArrayMappedTrie = @import("trie.zig").HashArrayMappedTrie; pub fn main() !void { var gpa = std.heap.GeneralPurposeAllocator(.{}){}; @@ -7,25 +7,47 @@ pub fn main() !void { const allocator = gpa.allocator(); - var trie = try HashArrayMappedTrie.init(allocator); - defer trie.deinit(); + var trie = HashArrayMappedTrie([]const u8, void, Context(u32)).init(); + defer trie.deinit(allocator); - try trie.insert("and", {}); - try trie.insert("class", {}); - try trie.insert("else", {}); - try trie.insert("false", {}); - try trie.insert("for", {}); - try trie.insert("fun", {}); - try trie.insert("if", {}); - try trie.insert("nil", {}); - try trie.insert("or", {}); - try trie.insert("print", {}); - try trie.insert("return", {}); - try trie.insert("super", {}); - try trie.insert("this", {}); - try trie.insert("true", {}); - try trie.insert("var", {}); - try trie.insert("while", {}); + try trie.insert(allocator, "and", {}); + try trie.insert(allocator, "class", {}); + try trie.insert(allocator, "else", {}); + try trie.insert(allocator, "false", {}); + try trie.insert(allocator, "for", {}); + try trie.insert(allocator, "fun", {}); + try trie.insert(allocator, "if", {}); + try trie.insert(allocator, "nil", {}); + try trie.insert(allocator, "or", {}); + try trie.insert(allocator, "print", {}); + try trie.insert(allocator, "return", {}); + try trie.insert(allocator, "super", {}); + try trie.insert(allocator, "this", {}); + try trie.insert(allocator, "true", {}); + try trie.insert(allocator, "var", {}); + try trie.insert(allocator, "while", {}); - trie.walk(); + try trie.print(); +} + +pub fn Context(comptime HashCode: type) type { + const Log2Int = std.math.Log2Int; + + return struct { + pub const Digest = HashCode; + + pub inline fn hash(key: []const u8) Digest { + // the MSB will represent 'z' + const offset = @typeInfo(Digest).Int.bits - 26; + + var result: Digest = 0; + for (key) |c| result |= @as(Digest, 1) << @intCast(Log2Int(Digest), offset + c - 'a'); + + return result; + } + + pub inline fn eql(left: []const u8, right: []const u8) bool { + return std.mem.eql(u8, left, right); + } + }; } diff --git a/src/tests.zig b/src/tests.zig index 76bb99b..f6acc18 100644 --- a/src/tests.zig +++ b/src/tests.zig @@ -1,6 +1,5 @@ comptime { - _ = @import("HashArrayMappedTrie.zig"); - _ = @import("main.zig"); + _ = @import("trie.zig"); } test { diff --git a/src/trie.zig b/src/trie.zig new file mode 100644 index 0000000..74e4f48 --- /dev/null +++ b/src/trie.zig @@ -0,0 +1,266 @@ +const std = @import("std"); + +const Allocator = std.mem.Allocator; +const Log2Int = std.math.Log2Int; + +/// Hash Array Mapped Trie +/// https://idea.popcount.org/2012-07-25-introduction-to-hamt/ +pub fn HashArrayMappedTrie(comptime K: type, comptime V: type, comptime Context: type) type { + return struct { + const Self = @This(); + + const Digest = Context.Digest; // as in Hash Code or Hash Digest + const table_size = @typeInfo(Digest).Int.bits; + const t = @intCast(Log2Int(Digest), @typeInfo(Log2Int(Digest)).Int.bits); + + root: [table_size]?*Node, + + const Node = union(enum) { kv: Pair, table: Table }; + const Table = struct { map: u32 = 0, base: [*]Node }; + const Pair = struct { key: K, value: V }; + + pub fn init() Self { + return Self{ .root = [_]?*Node{null} ** table_size }; + } + + pub fn deinit(self: *Self, allocator: Allocator) void { + for (self.root) |maybe_node| { + const node = maybe_node orelse continue; + + _deinit(allocator, node); + allocator.destroy(node); + } + } + + fn _deinit(allocator: Allocator, node: *Node) void { + switch (node.*) { + .kv => |_| return, // will be deallocated by caller + .table => |table| { + const amt_ptr = table.base[0..@popCount(table.map)]; // Array Mapped Table + + for (amt_ptr) |*sub_node| { + if (sub_node.* == .table) { + _deinit(allocator, sub_node); + } + } + + allocator.free(amt_ptr); + }, + } + } + + fn tableIdx(hash: Digest, offset: u16) Log2Int(Digest) { + const shift_amt = @intCast(Log2Int(Digest), table_size - offset); + + return @truncate(Log2Int(Digest), hash >> shift_amt); + } + + pub fn search(self: *Self, key: K) ?Pair { + const hash = Context.hash(key); + + // most siginificant t bits from hash + var hash_offset: Log2Int(Digest) = t; + var current: *Node = self.root[tableIdx(hash, hash_offset)] orelse return null; + + while (true) { + switch (current.*) { + .table => |table| { + const mask = @as(Digest, 1) << tableIdx(hash, hash_offset); + + if (table.map & mask != 0) { + const idx = @popCount(table.map & (mask - 1)); + current = &table.base[idx]; + + hash_offset += t; + } else return null; // hash table entry is empty + }, + .kv => |pair| { + if (!Context.eql(pair.key, key)) return null; + return pair; + }, + } + } + } + + pub fn insert(self: *Self, allocator: Allocator, comptime key: K, value: V) !void { + const hash = Context.hash(key); + + // most siginificant t bits from hash + var hash_offset: Log2Int(Digest) = t; + const root_idx = tableIdx(hash, hash_offset); + + var current: *Node = self.root[root_idx] orelse { + // node in root table is empty, place the KV here + const node = try allocator.create(Node); + node.* = .{ .kv = .{ .key = key, .value = value } }; + + self.root[root_idx] = node; + return; + }; + + while (true) { + const mask = @as(Digest, 1) << tableIdx(hash, hash_offset); + + switch (current.*) { + .table => |*table| { + if (table.map & mask == 0) { + // Empty + const old_len = @popCount(table.map); + const new_base = try allocator.alloc(Node, old_len + 1); + const new_map = table.map | mask; + + var i: Log2Int(Digest) = 0; + for (0..table_size) |shift| { + const mask_loop = @as(Digest, 1) << @intCast(u5, shift); + + if (new_map & mask_loop != 0) { + defer i += 1; + + const idx = @popCount(table.map & (mask_loop - 1)); + const copy = if (mask == mask_loop) Node{ .kv = Pair{ .key = key, .value = value } } else table.base[idx]; + new_base[i] = copy; + } + } + + allocator.free(table.base[0..old_len]); + table.base = new_base.ptr; + table.map = new_map; + + return; // inserted an elemnt into the Trie + } else { + // Found an entry in the array, continue loop (?) + const idx = @popCount(table.map & (mask - 1)); + current = &table.base[idx]; + + hash_offset += t; // Go one layer deper + } + }, + .kv => |prev_pair| { + const prev_hash = Context.hash(prev_pair.key); + const prev_mask = @as(Digest, 1) << tableIdx(prev_hash, hash_offset); + + switch (std.math.order(mask, prev_mask)) { + .lt, .gt => { + // there are no collisions between the two hash subsets. + const pairs = try allocator.alloc(Node, 2); + const map = mask | prev_mask; + + pairs[@popCount(map & (prev_mask - 1))] = .{ .kv = prev_pair }; + pairs[@popCount(map & (mask - 1))] = .{ .kv = .{ .key = key, .value = value } }; + + current.* = .{ .table = .{ .map = map, .base = pairs.ptr } }; + return; + }, + .eq => { + const copied_pair = try allocator.alloc(Node, 1); + copied_pair[0] = .{ .kv = prev_pair }; + + current.* = .{ .table = .{ .map = mask, .base = copied_pair.ptr } }; + }, + } + }, + } + } + } + + pub fn print(self: *Self) !void { + const stdout = std.io.getStdOut().writer(); + var buffered = std.io.bufferedWriter(stdout); + + const w = buffered.writer(); + + for (self.root, 0..) |maybe_node, i| { + try w.print("{:0>2}: ", .{i}); + + if (maybe_node) |node| { + try _print(w, node, 1); + } else { + try w.print("null\n", .{}); + } + } + + try buffered.flush(); + } + + fn _print(w: anytype, node: *Node, depth: u16) !void { + // @compileLog(@TypeOf(w)); + + switch (node.*) { + .kv => |pair| { + try w.print(".{{ .key = \"{s}\", .value = {} }}\n", .{ pair.key, pair.value }); + }, + .table => |table| { + try w.print(".{{ .map = 0x{X:0>8}, .ptr = {*} }}\n", .{ table.map, table.base }); + + for (0..@popCount(table.map)) |i| { + for (0..depth) |_| try w.print(" ", .{}); + try w.print("{:0>2}: ", .{i}); + + try _print(w, &table.base[i], depth + 1); + } + }, + } + } + }; +} + +fn TestContext(comptime HashCode: type) type { + return struct { + pub const Digest = HashCode; + + pub inline fn hash(key: []const u8) Digest { + // the MSB will represent 'z' + const offset = @typeInfo(Digest).Int.bits - 26; + + var result: Digest = 0; + for (key) |c| result |= @as(Digest, 1) << @intCast(Log2Int(Digest), offset + c - 'a'); + + return result; + } + + pub inline fn eql(left: []const u8, right: []const u8) bool { + return std.mem.eql(u8, left, right); + } + }; +} + +const TestHamt = HashArrayMappedTrie([]const u8, void, TestContext(u32)); + +test "trie init" { + _ = TestHamt.init(); +} + +test "init and deinit" { + const allocator = std.testing.allocator; + + var trie = TestHamt.init(); + defer trie.deinit(allocator); +} + +test "trie insert" { + const allocator = std.testing.allocator; + + var trie = TestHamt.init(); + defer trie.deinit(allocator); + + try trie.insert(allocator, "hello", {}); + try trie.insert(allocator, "world", {}); +} + +test "trie search" { + const Pair = TestHamt.Pair; + const allocator = std.testing.allocator; + + var trie = TestHamt.init(); + defer trie.deinit(allocator); + + try std.testing.expectEqual(@as(?Pair, null), trie.search("sdvx")); + + try trie.insert(allocator, "sdvx", {}); + + try std.testing.expectEqual(@as(?Pair, .{ .key = "sdvx", .value = {} }), trie.search("sdvx")); + try std.testing.expectEqual(@as(?Pair, null), trie.search("")); + + try trie.insert(allocator, "", {}); + try std.testing.expectEqual(@as(?Pair, .{ .key = "", .value = {} }), trie.search("")); +}