From 77d11c9903d7f4a3a31ef59fc954421d7dfa0eb1 Mon Sep 17 00:00:00 2001 From: Rekai Musuka Date: Tue, 23 May 2023 22:45:39 -0500 Subject: [PATCH] fix: replace channel impl --- src/lib.zig | 229 +++++++++++++++++++++++++++++++++++----------------- 1 file changed, 155 insertions(+), 74 deletions(-) diff --git a/src/lib.zig b/src/lib.zig index 711b09d..958782c 100644 --- a/src/lib.zig +++ b/src/lib.zig @@ -1,106 +1,187 @@ const std = @import("std"); const Log2Int = std.math.Log2Int; +const Allocator = std.mem.Allocator; -const EmuMessage = enum { Pause, Resume, Quit }; -const GuiMessage = enum { Paused, Quit }; +// TODO: Rewrite +// pub const TwoWayChannel = struct { +// const Self = @This(); -pub const TwoWayChannel = struct { - const Self = @This(); +// emu: Channel(EmuMessage), +// gui: Channel(GuiMessage), - emu: Channel(EmuMessage), - gui: Channel(GuiMessage), +// pub fn init(items: []u8) Self { +// comptime std.debug.assert(@sizeOf(EmuMessage) == @sizeOf(GuiMessage)); +// comptime std.debug.assert(@sizeOf(@typeInfo([]u8).Pointer.child) == @sizeOf(EmuMessage)); - pub fn init(items: []u8) Self { - comptime std.debug.assert(@sizeOf(EmuMessage) == @sizeOf(GuiMessage)); - comptime std.debug.assert(@sizeOf(@typeInfo([]u8).Pointer.child) == @sizeOf(EmuMessage)); +// std.debug.assert(items.len % 2 == 0); - std.debug.assert(items.len % 2 == 0); +// const left = @ptrCast([*]EmuMessage, items)[0 .. items.len / 2]; +// const right = @ptrCast([*]GuiMessage, items)[items.len / 2 .. items.len]; - const left = @ptrCast([*]EmuMessage, items)[0 .. items.len / 2]; - const right = @ptrCast([*]GuiMessage, items)[items.len / 2 .. items.len]; +// return .{ .emu = Channel(EmuMessage).init(left), .gui = Channel(GuiMessage).init(right) }; +// } +// }; - return .{ .emu = Channel(EmuMessage).init(left), .gui = Channel(GuiMessage).init(right) }; - } -}; - -fn Channel(comptime T: type) type { +pub fn Channel(comptime T: type, comptime N: usize) type { return struct { - const Self = @This(); const Index = usize; + const capacity_limit = (@as(Index, 1) << @typeInfo(Index).Int.bits - 1) - 1; // half the range of index type - const Atomic = std.atomic.Atomic; + tx: Sender, + rx: Receiver, - const max_capacity = (@as(Index, 1) << @typeInfo(Index).Int.bits - 1) - 1; // half the range of index type - const log = std.log.scoped(.Channel); + pub const Sender = struct { + const Self = @This(); - read: Atomic(Index), - write: Atomic(Index), - buf: []T, + read: *Index, + write: *Index, + ptr: *[N]T, - const Error = error{buffer_full}; + const Error = error{buffer_full}; - pub fn init(buf: []T) Self { - std.debug.assert(std.math.isPowerOfTwo(buf.len)); // capacity must be a power of two - std.debug.assert(buf.len <= max_capacity); + pub fn send(self: Self, value: T) void { + const idx_r = @atomicLoad(Index, self.read, .Acquire); + const idx_w = @atomicLoad(Index, self.write, .Acquire); + + // Check to see if Queue is full + if (idx_w - idx_r == N) @panic("Channel: Buffer is full"); + + self.ptr[mask(idx_w)] = value; + + std.atomic.fence(.Release); + @atomicStore(Index, self.write, idx_w + 1, .Release); + } + + pub fn len(self: Self) Index { + const idx_r = @atomicLoad(Index, self.read, .Acquire); + const idx_w = @atomicLoad(Index, self.write, .Acquire); + + return idx_w - idx_r; + } + }; + + pub const Receiver = struct { + const Self = @This(); + + read: *Index, + write: *Index, + ptr: *[N]T, + + pub fn recv(self: Self) ?T { + const idx_r = @atomicLoad(Index, self.read, .Acquire); + const idx_w = @atomicLoad(Index, self.write, .Acquire); + + if (idx_r == idx_w) return null; + + std.atomic.fence(.Acquire); + const value = self.ptr[mask(idx_r)]; + + std.atomic.fence(.Release); + @atomicStore(Index, self.read, idx_r + 1, .Release); + + return value; + } + + pub fn peek(self: Self) ?T { + const idx_r = @atomicLoad(Index, self.read, .Acquire); + const idx_w = @atomicLoad(Index, self.write, .Acquire); + + if (idx_r == idx_w) return null; + + std.atomic.fence(.Acquire); + return self.ptr[mask(idx_r)]; + } + + pub fn len(self: Self) Index { + const idx_r = @atomicLoad(Index, self.read, .Acquire); + const idx_w = @atomicLoad(Index, self.write, .Acquire); + + return idx_w - idx_r; + } + }; + + fn mask(idx: Index) Index { + return idx & (@intCast(Index, N) - 1); + } + + pub fn init(allocator: Allocator) !Channel(T, N) { + const buf = try allocator.alloc(T, N); + const indicies = try allocator.alloc(Index, 2); return .{ - .read = Atomic(Index).init(0), - .write = Atomic(Index).init(0), - .buf = buf, + .tx = Sender{ + .ptr = buf[0..N], + .read = &indicies[0], + .write = &indicies[1], + }, + .rx = Receiver{ + .ptr = buf[0..N], + .read = &indicies[0], + .write = &indicies[1], + }, }; } - pub fn push(self: *Self, value: T) void { - const read_idx = self.read.load(.Acquire); - const write_idx = self.write.load(.Acquire); + pub fn deinit(self: *Channel(T, N), allocator: Allocator) void { + const indicies: []Index = @ptrCast([*]Index, self.tx.read)[0..2]; - // Check to see if Queue is full - if (write_idx - read_idx == self.buf.len) @panic("Channel: Buffer is full"); + allocator.free(indicies); + allocator.free(self.tx.ptr); - self.buf[self.mask(write_idx)] = value; - - std.atomic.fence(.Release); - self.write.store(write_idx + 1, .Release); + self.* = undefined; } - pub fn pop(self: *Self) ?T { - const read_idx = self.read.load(.Acquire); - const write_idx = self.write.load(.Acquire); - - if (read_idx == write_idx) return null; - - std.atomic.fence(.Acquire); - const value = self.buf[self.mask(read_idx)]; - - std.atomic.fence(.Release); - self.read.store(read_idx + 1, .Release); - - return value; - } - - pub fn peek(self: *const Self) ?T { - const read_idx = self.read.load(.Acquire); - const write_idx = self.write.load(.Acquire); - - if (read_idx == write_idx) return null; - - std.atomic.fence(.Acquire); - return self.buf[self.mask(read_idx)]; - } - - pub fn len(self: *const Self) Index { - const read_idx = self.read.load(.Acquire); - const write_idx = self.write.load(.Acquire); - - return write_idx - read_idx; - } - - fn mask(self: *const Self, idx: Index) Index { - return idx & (self.buf.len - 1); + comptime { + std.debug.assert(std.math.isPowerOfTwo(N)); + std.debug.assert(N <= capacity_limit); } }; } +test "Channel init + deinit" { + var ch = try Channel(u8, 64).init(std.testing.allocator); + defer ch.deinit(std.testing.allocator); +} + +test "Channel basic queue" { + var ch = try Channel(u8, 64).init(std.testing.allocator); + defer ch.deinit(std.testing.allocator); + + ch.tx.send(128); + + try std.testing.expectEqual(@as(?u8, 128), ch.rx.recv()); +} + +test "Channel basic multithreaded" { + const builtin = @import("builtin"); + + if (builtin.single_threaded) + return error.SkipZigTest; + + const run_tx = struct { + fn run(tx: anytype) void { + tx.send(128); + } + }.run; + + const run_rx = struct { + fn run(rx: anytype) !void { + while (rx.recv()) |value| { + try std.testing.expectEqual(@as(?u8, 128), value); + } + } + }.run; + + var ch = try Channel(u8, 64).init(std.testing.allocator); + defer ch.deinit(std.testing.allocator); + + const tx_handle = try std.Thread.spawn(.{}, run_tx, .{&ch.tx}); + defer tx_handle.join(); + + const rx_handle = try std.Thread.spawn(.{}, run_rx, .{&ch.rx}); + defer rx_handle.join(); +} + pub fn RingBuffer(comptime T: type) type { return struct { const Self = @This();