477 lines
19 KiB
Zig
477 lines
19 KiB
Zig
const std = @import("std");
|
|
|
|
const Allocator = std.mem.Allocator;
|
|
const Log2Int = std.math.Log2Int;
|
|
|
|
/// Hash Array Mapped Trie
/// https://idea.popcount.org/2012-07-25-introduction-to-hamt/
///
/// Returns a trie type that maps keys of type `K` to values of type `V`.
/// `Context` supplies the unsigned integer type `Digest`, `hash(K) Digest` and
/// `eql(K, K) bool` (checked by `verify`). Every level of the trie consumes `t`
/// bits of the digest, most significant bits first, to pick one of `table_size`
/// slots.
///
/// The trie is unmanaged: every call that may allocate or free takes the allocator,
/// and the same allocator must be passed for the whole lifetime of the trie.
pub fn HashArrayMappedTrie(comptime K: type, comptime V: type, comptime Context: type) type {
    comptime {
        verify(K, Context);

        // A `t`-bit index has to address exactly the bits of a table's bitmap (and the
        // slots of the root table), which only holds when the bit width is a power of two.
        const bits = @typeInfo(Context.Digest).Int.bits;
        if (bits < 2 or (bits & (bits - 1)) != 0) {
            @compileError("Context.Digest must have a power-of-two bit width of at least 2, however it was actually " ++ @typeName(Context.Digest));
        }
    }

    return struct {
        const Self = @This();

        const Digest = Context.Digest; // as in Hash Code or Hash Digest
        /// Number of slots per table, one for every bit of `Table.map`.
        const table_size = @typeInfo(Digest).Int.bits;
        /// Number of digest bits consumed per level of the trie.
        const t: Log2Int(Digest) = @intCast(@typeInfo(Log2Int(Digest)).Int.bits);

        free_list: FreeList,
        root: []?*Node,

        /// Either a stored pair, or a sub-table indexed by the next `t` bits of the digest.
        const Node = union(enum) { kv: Pair, table: Table };
        /// `map` has one bit set per occupied slot; `base` points at `@popCount(map)` nodes
        /// stored in ascending slot order.
        const Table = struct { map: Digest = 0, base: [*]Node };
        pub const Pair = struct { key: K, value: V };

        /// Responsible for managing HAMT Memory
        const FreeList = struct {
            /// `list[len - 1]` is the head of a linked list of unused node arrays that are
            /// exactly `len` elements long. The head lives inline in this array, every
            /// following list node is allocated individually.
            list: *[table_size]?FreeList.Node,

            // The index of the array of the linked list any given node belongs to
            // informs us of how many elements there are in the [*]Node ptr.
            const Node = struct {
                inner: [*]Self.Node,
                next: ?*FreeList.Node = null,

                /// Releases the `len` trie nodes `inner` points at (not the list node itself).
                pub fn deinit(self: *const FreeList.Node, allocator: Allocator, len: usize) void {
                    switch (len) {
                        0 => unreachable,
                        1 => allocator.destroy(@as(*Self.Node, @ptrCast(self.inner))),
                        else => allocator.free(self.inner[0..len]),
                    }
                }
            };

            /// Creates an empty free list.
            pub fn init(allocator: Allocator) !FreeList {
                const list = try allocator.create([table_size]?FreeList.Node);
                @memset(list, null);

                return .{ .list = list };
            }

            /// Releases every array that is still tracked, every list node and the list itself.
            pub fn deinit(self: *FreeList, allocator: Allocator) void {
                for (self.list, 0..) |maybe_head, i| {
                    const len = i + 1;

                    // The head is stored inline in `self.list` (released below as part of
                    // the array), so only the memory it tracks is freed here.
                    const head = maybe_head orelse continue;
                    head.deinit(allocator, len);

                    var maybe_node = head.next;
                    while (maybe_node) |node| {
                        maybe_node = node.next; // read the link before its owner is destroyed

                        node.deinit(allocator, len);
                        allocator.destroy(node);
                    }
                }

                allocator.destroy(self.list);
                self.* = undefined;
            }

            /// Returns `len` nodes, reusing a tracked array of that length when there is
            /// one and falling back to `allocator` otherwise. The contents are undefined.
            ///
            /// `T` must be the trie's node type. Fails with `error.unexpected_table_length`
            /// unless `1 <= len <= table_size`.
            pub fn alloc(self: *FreeList, allocator: Allocator, comptime T: type, len: usize) ![]T {
                if (T != Self.Node) @compileError("FreeList only manages " ++ @typeName(Self.Node) ++ ", however " ++ @typeName(T) ++ " was requested");
                if (len == 0 or len > table_size) return error.unexpected_table_length;

                // If head is null, (head is self.list[len - 1]) then there was nothing in the free list
                // therefore we should use the backup allocator
                var current: *FreeList.Node = &(self.list[len - 1] orelse return try allocator.alloc(T, len));
                var prev: ?*FreeList.Node = null;

                // hand out the array tracked by the tail of the list
                while (current.next) |next| {
                    prev = current;
                    current = next;
                }

                const ret_ptr = current.inner;

                if (prev) |before_tail| {
                    before_tail.next = null; // remove node from linked list
                    allocator.destroy(current);
                } else {
                    // The tail is also the head, which lives inline in `self.list`:
                    // emptying the slot is all that is needed.
                    self.list[len - 1] = null;
                }

                // this is safe because we've grabbed this many-ptr from the linked list of AMTs that have this size
                return ret_ptr[0..len];
            }

            /// Same as `alloc` for a single node.
            pub fn create(self: *FreeList, allocator: Allocator, comptime T: type) !*T {
                const slice = try self.alloc(allocator, T, 1);
                return &slice[0];
            }

            /// Free'd nodes aren't deallocated, but instead are tracked by a free list where they
            /// may be reused in the future
            ///
            /// We may allocate to append a new FreeList Node to the end of the Linked List
            pub fn free(self: *FreeList, allocator: Allocator, ptr: []Self.Node) !void {
                if (ptr.len == 0 or ptr.len > table_size) return error.unexpected_table_length;

                var current: *FreeList.Node = &(self.list[ptr.len - 1] orelse {
                    // There were no nodes present so start off the linked list
                    self.list[ptr.len - 1] = .{ .inner = ptr.ptr };
                    return;
                });

                // traverse the linked list
                while (current.next) |next| current = next;

                const tail = try allocator.create(FreeList.Node);
                tail.* = .{ .inner = ptr.ptr };

                current.next = tail;
            }

            /// Same as `free` for a single node.
            pub fn destroy(self: *FreeList, allocator: Allocator, node: *Self.Node) !void {
                return self.free(allocator, @as([*]Self.Node, @ptrCast(node))[0..1]);
            }
        };

        /// Creates an empty trie. Release it with `deinit`, passing the same allocator.
        pub fn init(allocator: Allocator) !Self {
            // TODO: Add ability to have a larger root node (for quicker lookup times)
            const root = try allocator.alloc(?*Node, table_size);
            errdefer allocator.free(root); // don't leak the root table if the free list can't be created
            @memset(root, null);

            return Self{ .root = root, .free_list = try FreeList.init(allocator) };
        }

        /// Releases every node of the trie and everything tracked by the free list.
        /// Keys and values themselves are not touched.
        pub fn deinit(self: *Self, allocator: Allocator) void {
            for (self.root) |maybe_node| {
                const node = maybe_node orelse continue;

                _deinit(allocator, node);
                allocator.destroy(node);
            }

            allocator.free(self.root);
            self.free_list.deinit(allocator);
            self.* = undefined;
        }

        /// Frees everything below `node`. `node` itself is released by the caller.
        fn _deinit(allocator: Allocator, node: *Node) void {
            switch (node.*) {
                .kv => return, // will be deallocated by caller
                .table => |table| {
                    const amt_ptr = table.base[0..@popCount(table.map)]; // Array Mapped Table

                    for (amt_ptr) |*sub_node| {
                        if (sub_node.* == .table) _deinit(allocator, sub_node);
                    }

                    allocator.free(amt_ptr);
                },
            }
        }

        /// Extracts the `t` bits of `hash` that end `offset` bits below the most
        /// significant bit. `offset` must be in `1..=table_size`.
        fn tableIdx(hash: Digest, offset: u16) Log2Int(Digest) {
            const shift_amt: Log2Int(Digest) = @intCast(table_size - offset);

            return @truncate(hash >> shift_amt);
        }

        /// Returns the pair stored for `key`, or null when the trie does not contain it.
        pub fn search(self: *Self, key: K) ?Pair {
            const hash = Context.hash(key);

            // most significant t bits from hash
            //
            // Wider than Log2Int(Digest) on purpose: stepping out of the deepest table
            // moves the offset past `table_size`, where only a pair can be found.
            var hash_offset: u16 = t;
            var current: *Node = self.root[tableIdx(hash, hash_offset)] orelse return null;

            while (true) {
                switch (current.*) {
                    .table => |table| {
                        const mask = @as(Digest, 1) << tableIdx(hash, hash_offset);

                        if (table.map & mask == 0) return null; // hash table entry is empty

                        current = &table.base[@popCount(table.map & (mask - 1))];
                        hash_offset += t;
                    },
                    .kv => |pair| return if (Context.eql(pair.key, key)) pair else null,
                }
            }
        }

        /// Stores `value` under `key`. When `key` is already present its value is
        /// overwritten and the stored key is kept.
        ///
        /// Fails with `error.hash_exhausted` when a different key shares every digest
        /// bit the trie can index with, and with the allocator's error when memory for
        /// a new table can't be obtained. The trie is left unchanged in both cases.
        pub fn insert(self: *Self, allocator: Allocator, key: K, value: V) !void {
            const hash = Context.hash(key);

            // most significant t bits from hash
            var hash_offset: u16 = t; // see `search` for why this is wider than Log2Int(Digest)
            const root_idx = tableIdx(hash, hash_offset);

            var current: *Node = self.root[root_idx] orelse {
                // node in root table is empty, place the KV here
                const node = try self.free_list.create(allocator, Node);
                node.* = .{ .kv = .{ .key = key, .value = value } };

                self.root[root_idx] = node;
                return;
            };

            while (true) {
                switch (current.*) {
                    .table => |*table| {
                        // Tables are only ever created while digest bits remain (see the
                        // `.kv` case), so `hash_offset` is a valid offset here.
                        const mask = @as(Digest, 1) << tableIdx(hash, hash_offset);
                        const pos: usize = @popCount(table.map & (mask - 1));

                        if (table.map & mask != 0) {
                            // Slot is occupied, go one layer deeper
                            current = &table.base[pos];
                            hash_offset += t;
                            continue;
                        }

                        // Slot is empty: build a table that is one element larger with the
                        // new pair at its sorted position.
                        const old = table.base[0..@popCount(table.map)];
                        const new = try self.free_list.alloc(allocator, Node, old.len + 1);

                        for (new[0..pos], old[0..pos]) |*dst, src| dst.* = src;
                        new[pos] = .{ .kv = .{ .key = key, .value = value } };
                        for (new[pos + 1 ..], old[pos..]) |*dst, src| dst.* = src;

                        table.base = new.ptr;
                        table.map |= mask;

                        // Recycling the old array is best effort: when the free list can't
                        // grow, hand the memory straight back instead of failing an insert
                        // that has already taken effect.
                        self.free_list.free(allocator, old) catch allocator.free(old);

                        return; // inserted an element into the Trie
                    },
                    .kv => |*existing| {
                        if (Context.eql(existing.key, key)) {
                            existing.value = value;
                            return;
                        }

                        // Two different keys that agree on every usable digest bit can't be
                        // told apart by going deeper.
                        if (hash_offset > table_size) return error.hash_exhausted;

                        const prev_pair = existing.*; // copy, `current` is overwritten below
                        const mask = @as(Digest, 1) << tableIdx(hash, hash_offset);
                        const prev_mask = @as(Digest, 1) << tableIdx(Context.hash(prev_pair.key), hash_offset);

                        if (mask != prev_mask) {
                            // there are no collisions between the two hash subsets.
                            const pairs = try self.free_list.alloc(allocator, Node, 2);
                            const map = mask | prev_mask;

                            pairs[@popCount(map & (prev_mask - 1))] = .{ .kv = prev_pair };
                            pairs[@popCount(map & (mask - 1))] = .{ .kv = .{ .key = key, .value = value } };

                            current.* = .{ .table = .{ .map = map, .base = pairs.ptr } };
                            return;
                        }

                        // Both keys land in the same slot at this level: push the existing
                        // pair one level down. The next iteration takes the `.table` case
                        // and descends into the table created here.
                        const copied_pair = try self.free_list.create(allocator, Node);
                        copied_pair.* = .{ .kv = prev_pair };

                        current.* = .{ .table = .{ .map = mask, .base = @as([*]Node, @ptrCast(copied_pair)) } };
                    },
                }
            }
        }

        /// Writes a textual dump of the trie to stdout. Keys are formatted with `{s}`.
        pub fn print(self: *Self) !void {
            const stdout = std.io.getStdOut().writer();
            var buffered = std.io.bufferedWriter(stdout);

            const w = buffered.writer();

            for (self.root, 0..) |maybe_node, i| {
                try w.print("{:0>2}: ", .{i});

                if (maybe_node) |node| {
                    try _print(w, node, 1);
                } else {
                    try w.print("null\n", .{});
                }
            }

            try buffered.flush();
        }

        /// Writes `node` and, for tables, every child indented by `depth`.
        fn _print(w: anytype, node: *Node, depth: u16) !void {
            switch (node.*) {
                .kv => |pair| {
                    try w.print(".{{ .key = \"{s}\", .value = {} }}\n", .{ pair.key, pair.value });
                },
                .table => |table| {
                    try w.print(".{{ .map = 0x{X:0>8}, .ptr = {*} }}\n", .{ table.map, table.base });

                    for (0..@popCount(table.map)) |i| {
                        for (0..depth) |_| try w.print(" ", .{});
                        try w.print("{:0>2}: ", .{i});

                        try _print(w, &table.base[i], depth + 1);
                    }
                },
            }
        }
    };
}
|
|
|
|
/// Checks at compile time that `Context` can be used as the context of a trie
/// whose keys are of type `K`.
///
/// `Context` must be a struct, union or enum that publicly declares
///   * `Digest`, an unsigned integer type,
///   * `hash`, a function taking a `K` and returning a `Digest`,
///   * `eql`, a function taking two `K` and returning a `bool`.
///
/// Every violation, including a missing declaration, is reported through
/// `@compileError`.
pub fn verify(comptime K: type, comptime Context: type) void {
    // FIXME: Context should be able to be a pointer to a type

    switch (@typeInfo(Context)) {
        .Struct, .Union, .Enum => {},
        .Pointer => @compileError("Pointer trie contexts have yet to be implemented"),
        else => @compileError("Trie context must be a type with Digest, hash(" ++ @typeName(K) ++ ") Digest, and eql(" ++ @typeName(K) ++ ", " ++ @typeName(K) ++ ") bool"),
    }

    if (@hasDecl(Context, "Digest")) {
        const Digest = Context.Digest;
        const info = @typeInfo(Digest);

        if (info != .Int) @compileError("Context.Digest must be an integer, however it was actually " ++ @typeName(Digest));
        if (info.Int.signedness != .unsigned) @compileError("Context.Digest must be an unsigned integer, however it was actually an " ++ @typeName(Digest));
    } else {
        @compileError("Trie context " ++ @typeName(Context) ++ " must publicly declare Digest");
    }

    if (@hasDecl(Context, "hash")) {
        const hash = Context.hash;
        const HashFn = @TypeOf(hash);

        const info = @typeInfo(HashFn);

        if (info != .Fn) @compileError("Context.hash must be a function, however it was actually " ++ @typeName(HashFn));

        const func = info.Fn;
        if (func.params.len != 1) @compileError("Invalid Context.hash signature. Expected hash(" ++ @typeName(K) ++ "), but was actually " ++ @typeName(HashFn));

        // a parameter type of null (e.g. anytype) is rejected before it is unwrapped
        if (func.params[0].type == null or func.params[0].type.? != K) {
            const type_str = if (func.params[0].type) |Param| @typeName(Param) else "null";
            @compileError("Invalid Context.hash signature. Parameter must be " ++ @typeName(K) ++ ", however it was " ++ type_str);
        }

        if (func.return_type == null or func.return_type.? != Context.Digest) {
            const type_str = if (func.return_type) |Return| @typeName(Return) else "null";

            @compileError("Invalid Context.hash signature. Return type must be " ++ @typeName(Context.Digest) ++ ", however it was " ++ type_str);
        }
    } else {
        @compileError("Trie context " ++ @typeName(Context) ++ " must publicly declare hash(" ++ @typeName(K) ++ ") Digest");
    }

    if (@hasDecl(Context, "eql")) {
        const eql = Context.eql;
        const EqlFn = @TypeOf(eql);

        const info = @typeInfo(EqlFn);

        if (info != .Fn) @compileError("Context.eql must be a function, however it was actually " ++ @typeName(EqlFn));

        const func = info.Fn;
        if (func.params.len != 2) @compileError("Invalid Context.eql signature. Expected eql(" ++ @typeName(K) ++ ", " ++ @typeName(K) ++ "), but was actually " ++ @typeName(EqlFn));

        // a parameter type of null (e.g. anytype) is rejected before it is unwrapped
        if (func.params[0].type == null or func.params[0].type.? != K) {
            const type_str = if (func.params[0].type) |Param| @typeName(Param) else "null";
            @compileError("Invalid Context.eql signature. First parameter must be " ++ @typeName(K) ++ ", however it was " ++ type_str);
        }

        if (func.params[1].type == null or func.params[1].type.? != K) {
            const type_str = if (func.params[1].type) |Param| @typeName(Param) else "null";
            @compileError("Invalid Context.eql signature. Second parameter must be " ++ @typeName(K) ++ ", however it was " ++ type_str);
        }

        if (func.return_type == null or func.return_type.? != bool) {
            const type_str = if (func.return_type) |Return| @typeName(Return) else "null";

            @compileError("Invalid Context.eql signature. Return type must be " ++ @typeName(bool) ++ ", however it was " ++ type_str);
        }
    } else {
        @compileError("Trie context " ++ @typeName(Context) ++ " must publicly declare eql(" ++ @typeName(K) ++ ", " ++ @typeName(K) ++ ") bool");
    }
}
|
|
|
|
/// Trie context for byte-string keys: keys are hashed with Wyhash (seed 0) into a
/// 64-bit digest and compared byte for byte.
const StringContext = struct {
    pub const Digest = u64;

    const Hasher = std.hash.Wyhash;
    const seed: u64 = 0;

    /// Digest of the bytes of `key`.
    pub inline fn hash(key: []const u8) Digest {
        return Hasher.hash(seed, key);
    }

    /// True when both keys hold the same bytes.
    pub inline fn eql(left: []const u8, right: []const u8) bool {
        const same_bytes = std.mem.eql(u8, left, right);
        return same_bytes;
    }
};
|
|
|
|
// Set-like trie used by the tests below: byte-string keys, no value (`void`),
// hashed and compared through `StringContext`.
const StringTrie = HashArrayMappedTrie([]const u8, void, StringContext);
|
|
|
|
test "trie init" {
    // An empty trie can be created and released again; the testing allocator
    // reports anything that is left behind.
    const gpa = std.testing.allocator;

    var empty_trie = try StringTrie.init(gpa);
    defer empty_trie.deinit(gpa);
}
|
|
|
|
test "init and deinit" {
    // Creation followed by teardown, checked for leaks by the testing allocator.
    const testing = std.testing;

    var trie = try StringTrie.init(testing.allocator);
    defer trie.deinit(testing.allocator);
}
|
|
|
|
test "trie insert" {
    // Two distinct keys are inserted in order; the test passes when neither insert
    // fails and the testing allocator sees no leak on teardown.
    const gpa = std.testing.allocator;

    var trie = try StringTrie.init(gpa);
    defer trie.deinit(gpa);

    const first_key: []const u8 = "hello";
    const second_key: []const u8 = "world";

    try trie.insert(gpa, first_key, {});
    try trie.insert(gpa, second_key, {});
}
|
|
|
|
test "trie search" {
    // A key is only found once it has been inserted; the empty string is a key
    // like any other.
    const testing = std.testing;
    const Pair = StringTrie.Pair;
    const gpa = testing.allocator;

    const missing: ?Pair = null;

    var trie = try StringTrie.init(gpa);
    defer trie.deinit(gpa);

    try testing.expectEqual(missing, trie.search("sdvx"));

    try trie.insert(gpa, "sdvx", {});

    const sdvx_pair: ?Pair = .{ .key = "sdvx", .value = {} };
    try testing.expectEqual(sdvx_pair, trie.search("sdvx"));
    try testing.expectEqual(missing, trie.search(""));

    try trie.insert(gpa, "", {});

    const empty_pair: ?Pair = .{ .key = "", .value = {} };
    try testing.expectEqual(empty_pair, trie.search(""));
}
|
|
|
|
test "README.md example" {
    // Insert one key, then look up that key and one that was never inserted.
    const testing = std.testing;
    const gpa = testing.allocator;

    var trie = try StringTrie.init(gpa);
    defer trie.deinit(gpa);

    try trie.insert(gpa, "hello", {});

    const found: ?StringTrie.Pair = .{ .key = "hello", .value = {} };
    const not_found: ?StringTrie.Pair = null;

    try testing.expectEqual(found, trie.search("hello"));
    try testing.expectEqual(not_found, trie.search("world"));
}
|