diff --git a/src/any.zig b/src/any.zig index 1f64253..90620f0 100644 --- a/src/any.zig +++ b/src/any.zig @@ -34,6 +34,10 @@ const unpackStruct = @import("struct.zig").unpackStruct; const packUnion = @import("union.zig").packUnion; const unpackUnion = @import("union.zig").unpackUnion; +const getEnumSize = @import("enum.zig").getEnumSize; +const packEnum = @import("enum.zig").packEnum; +const unpackEnum = @import("enum.zig").unpackEnum; + inline fn isString(comptime T: type) bool { switch (@typeInfo(T)) { .pointer => |ptr_info| { @@ -56,6 +60,7 @@ pub fn sizeOfPackedAny(comptime T: type, value: T) usize { .bool => return getBoolSize(), .int => return getIntSize(T, value), .float => return getFloatSize(T, value), + .@"enum" => return getEnumSize(T, value), .pointer => |ptr_info| { if (ptr_info.size == .Slice) { if (isString(T)) { @@ -105,6 +110,7 @@ pub fn packAny(writer: anytype, value: anytype) !void { }, .@"struct" => return packStruct(writer, T, value), .@"union" => return packUnion(writer, T, value), + .@"enum" => return packEnum(writer, T, value), .optional => { if (value) |val| { return packAny(writer, val); @@ -125,6 +131,7 @@ pub fn unpackAny(reader: anytype, allocator: std.mem.Allocator, comptime T: type .float => return unpackFloat(reader, T), .@"struct" => return unpackStruct(reader, allocator, T), .@"union" => return unpackUnion(reader, allocator, T), + .@"enum" => return unpackEnum(reader, T), .pointer => |ptr_info| { if (ptr_info.size == .slice) { if (isString(T)) { diff --git a/src/enum.zig b/src/enum.zig new file mode 100644 index 0000000..1ac4174 --- /dev/null +++ b/src/enum.zig @@ -0,0 +1,193 @@ +const std = @import("std"); + +const maybePackNull = @import("null.zig").maybePackNull; + +const getIntSize = @import("int.zig").getIntSize; +const packInt = @import("int.zig").packInt; +const unpackInt = @import("int.zig").unpackInt; + +inline fn assertEnumType(comptime T: type) type { + switch (@typeInfo(T)) { + .@"enum" => return T, + .optional => |opt_info| { + return assertEnumType(opt_info.child); + }, + else => @compileError("Expected enum, got " ++ @typeName(T)), + } +} + +pub fn getMaxEnumSize(comptime T: type) usize { + const Type = assertEnumType(T); + const tag_type = @typeInfo(Type).@"enum".tag_type; + return 1 + @sizeOf(tag_type); +} + +pub fn getEnumSize(comptime T: type, value: T) usize { + if (@typeInfo(T) == .optional) { + if (value) |v| { + return getEnumSize(@typeInfo(T).optional.child, v); + } else { + return 1; // size of null + } + } + + const tag_type = @typeInfo(T).@"enum".tag_type; + const int_value = @intFromEnum(value); + return getIntSize(tag_type, int_value); +} + +pub fn packEnum(writer: anytype, comptime T: type, value_or_maybe_null: T) !void { + const Type = assertEnumType(T); + const value: Type = try maybePackNull(writer, T, value_or_maybe_null) orelse return; + + const tag_type = @typeInfo(Type).@"enum".tag_type; + const int_value = @intFromEnum(value); + + try packInt(writer, tag_type, int_value); +} + +pub fn unpackEnum(reader: anytype, comptime T: type) !T { + const Type = assertEnumType(T); + const tag_type = @typeInfo(Type).@"enum".tag_type; + + // Construct the optional tag type to match T's optionality + const OptionalTagType = if (@typeInfo(T) == .optional) ?tag_type else tag_type; + + // Use unpackInt directly with the constructed optional tag type + const int_value = try unpackInt(reader, OptionalTagType); + + // Handle the optional case + if (@typeInfo(T) == .optional) { + if (int_value) |value| { + return @enumFromInt(value); + } else { + return null; + } + } else { + return @enumFromInt(int_value); + } +} + +test "getMaxEnumSize" { + const PlainEnum = enum { foo, bar }; + const U8Enum = enum(u8) { foo = 1, bar = 2 }; + const U16Enum = enum(u16) { foo, bar }; + + try std.testing.expectEqual(2, getMaxEnumSize(PlainEnum)); // u1 + header + try std.testing.expectEqual(2, getMaxEnumSize(U8Enum)); // u8 + header + try std.testing.expectEqual(3, getMaxEnumSize(U16Enum)); // u16 + header +} + +test "getEnumSize" { + const U8Enum = enum(u8) { foo = 0, bar = 150 }; + + try std.testing.expectEqual(1, getEnumSize(U8Enum, .foo)); // fits in positive fixint + try std.testing.expectEqual(2, getEnumSize(U8Enum, .bar)); // requires u8 format +} + +test "pack/unpack enum" { + const PlainEnum = enum { foo, bar }; + const U8Enum = enum(u8) { foo = 1, bar = 2 }; + const U16Enum = enum(u16) { alpha = 1000, beta = 2000 }; + + // Test plain enum + { + var buffer = std.ArrayList(u8).init(std.testing.allocator); + defer buffer.deinit(); + + try packEnum(buffer.writer(), PlainEnum, .bar); + + var stream = std.io.fixedBufferStream(buffer.items); + const result = try unpackEnum(stream.reader(), PlainEnum); + try std.testing.expectEqual(PlainEnum.bar, result); + } + + // Test enum(u8) + { + var buffer = std.ArrayList(u8).init(std.testing.allocator); + defer buffer.deinit(); + + try packEnum(buffer.writer(), U8Enum, .bar); + + var stream = std.io.fixedBufferStream(buffer.items); + const result = try unpackEnum(stream.reader(), U8Enum); + try std.testing.expectEqual(U8Enum.bar, result); + } + + // Test enum(u16) + { + var buffer = std.ArrayList(u8).init(std.testing.allocator); + defer buffer.deinit(); + + try packEnum(buffer.writer(), U16Enum, .alpha); + + var stream = std.io.fixedBufferStream(buffer.items); + const result = try unpackEnum(stream.reader(), U16Enum); + try std.testing.expectEqual(U16Enum.alpha, result); + } +} + + +test "enum edge cases" { + // Test enum with explicit and auto values + const MixedEnum = enum(u8) { + first = 10, + second, // auto-assigned to 11 + third = 20, + fourth, // auto-assigned to 21 + }; + + var buffer = std.ArrayList(u8).init(std.testing.allocator); + defer buffer.deinit(); + + try packEnum(buffer.writer(), MixedEnum, .second); + + var stream = std.io.fixedBufferStream(buffer.items); + const result = try unpackEnum(stream.reader(), MixedEnum); + try std.testing.expectEqual(MixedEnum.second, result); + try std.testing.expectEqual(11, @intFromEnum(result)); +} + +test "optional enum" { + const TestEnum = enum(u8) { foo = 1, bar = 2 }; + const OptionalEnum = ?TestEnum; + + // Test non-null optional enum + { + var buffer = std.ArrayList(u8).init(std.testing.allocator); + defer buffer.deinit(); + + const value: OptionalEnum = .bar; + try packEnum(buffer.writer(), OptionalEnum, value); + + var stream = std.io.fixedBufferStream(buffer.items); + const result = try unpackEnum(stream.reader(), OptionalEnum); + try std.testing.expectEqual(@as(OptionalEnum, .bar), result); + } + + // Test null optional enum + { + var buffer = std.ArrayList(u8).init(std.testing.allocator); + defer buffer.deinit(); + + const value: OptionalEnum = null; + try packEnum(buffer.writer(), OptionalEnum, value); + + var stream = std.io.fixedBufferStream(buffer.items); + const result = try unpackEnum(stream.reader(), OptionalEnum); + try std.testing.expectEqual(@as(OptionalEnum, null), result); + } +} + +test "getEnumSize with optional" { + const TestEnum = enum(u8) { foo = 0, bar = 150 }; + const OptionalEnum = ?TestEnum; + + // Test non-null optional enum size + const value: OptionalEnum = .bar; + try std.testing.expectEqual(2, getEnumSize(OptionalEnum, value)); // requires u8 format + + // Test null optional enum size + const null_value: OptionalEnum = null; + try std.testing.expectEqual(1, getEnumSize(OptionalEnum, null_value)); // size of null +} \ No newline at end of file diff --git a/src/msgpack.zig b/src/msgpack.zig index 0e6fb23..7020dec 100644 --- a/src/msgpack.zig +++ b/src/msgpack.zig @@ -65,6 +65,11 @@ pub const UnionAsMapOptions = @import("union.zig").UnionAsMapOptions; pub const packUnion = @import("union.zig").packUnion; pub const unpackUnion = @import("union.zig").unpackUnion; +pub const getEnumSize = @import("enum.zig").getEnumSize; +pub const getMaxEnumSize = @import("enum.zig").getMaxEnumSize; +pub const packEnum = @import("enum.zig").packEnum; +pub const unpackEnum = @import("enum.zig").unpackEnum; + pub const packAny = @import("any.zig").packAny; pub const unpackAny = @import("any.zig").unpackAny; @@ -144,6 +149,10 @@ pub fn Packer(comptime Writer: type) type { return packUnion(self.writer, @TypeOf(value), value); } + pub fn writeEnum(self: Self, value: anytype) !void { + return packEnum(self.writer, @TypeOf(value), value); + } + pub fn write(self: Self, value: anytype) !void { return packAny(self.writer, value); } @@ -232,6 +241,10 @@ pub fn Unpacker(comptime Reader: type) type { return unpackUnion(self.reader, self.allocator, T); } + pub fn readEnum(self: Self, comptime T: type) !T { + return unpackEnum(self.reader, T); + } + pub fn read(self: Self, comptime T: type) !T { return unpackAny(self.reader, self.allocator, T); } @@ -304,3 +317,60 @@ test "encode/decode" { try std.testing.expectEqualStrings("John", decoded.value.name); try std.testing.expectEqual(20, decoded.value.age); } + +test "encode/decode enum" { + const Status = enum(u8) { pending = 1, active = 2, inactive = 3 }; + const PlainEnum = enum { foo, bar, baz }; + + // Test enum(u8) + { + var buffer = std.ArrayList(u8).init(std.testing.allocator); + defer buffer.deinit(); + + try encode(Status.active, buffer.writer()); + + const decoded = try decodeFromSlice(Status, std.testing.allocator, buffer.items); + defer decoded.deinit(); + + try std.testing.expectEqual(Status.active, decoded.value); + } + + // Test plain enum + { + var buffer = std.ArrayList(u8).init(std.testing.allocator); + defer buffer.deinit(); + + try encode(PlainEnum.bar, buffer.writer()); + + const decoded = try decodeFromSlice(PlainEnum, std.testing.allocator, buffer.items); + defer decoded.deinit(); + + try std.testing.expectEqual(PlainEnum.bar, decoded.value); + } + + // Test optional enum with null + { + var buffer = std.ArrayList(u8).init(std.testing.allocator); + defer buffer.deinit(); + + try encode(@as(?Status, null), buffer.writer()); + + const decoded = try decodeFromSlice(?Status, std.testing.allocator, buffer.items); + defer decoded.deinit(); + + try std.testing.expectEqual(@as(?Status, null), decoded.value); + } + + // Test optional enum with value + { + var buffer = std.ArrayList(u8).init(std.testing.allocator); + defer buffer.deinit(); + + try encode(@as(?Status, .pending), buffer.writer()); + + const decoded = try decodeFromSlice(?Status, std.testing.allocator, buffer.items); + defer decoded.deinit(); + + try std.testing.expectEqual(@as(?Status, .pending), decoded.value); + } +}