Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/any.zig
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ const unpackStruct = @import("struct.zig").unpackStruct;
const packUnion = @import("union.zig").packUnion;
const unpackUnion = @import("union.zig").unpackUnion;

const getEnumSize = @import("enum.zig").getEnumSize;
const packEnum = @import("enum.zig").packEnum;
const unpackEnum = @import("enum.zig").unpackEnum;

inline fn isString(comptime T: type) bool {
switch (@typeInfo(T)) {
.pointer => |ptr_info| {
Expand All @@ -56,6 +60,7 @@ pub fn sizeOfPackedAny(comptime T: type, value: T) usize {
.bool => return getBoolSize(),
.int => return getIntSize(T, value),
.float => return getFloatSize(T, value),
.@"enum" => return getEnumSize(T, value),
.pointer => |ptr_info| {
if (ptr_info.size == .Slice) {
if (isString(T)) {
Expand Down Expand Up @@ -105,6 +110,7 @@ pub fn packAny(writer: anytype, value: anytype) !void {
},
.@"struct" => return packStruct(writer, T, value),
.@"union" => return packUnion(writer, T, value),
.@"enum" => return packEnum(writer, T, value),
.optional => {
if (value) |val| {
return packAny(writer, val);
Expand All @@ -125,6 +131,7 @@ pub fn unpackAny(reader: anytype, allocator: std.mem.Allocator, comptime T: type
.float => return unpackFloat(reader, T),
.@"struct" => return unpackStruct(reader, allocator, T),
.@"union" => return unpackUnion(reader, allocator, T),
.@"enum" => return unpackEnum(reader, T),
.pointer => |ptr_info| {
if (ptr_info.size == .slice) {
if (isString(T)) {
Expand Down
193 changes: 193 additions & 0 deletions src/enum.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
const std = @import("std");

const maybePackNull = @import("null.zig").maybePackNull;

const getIntSize = @import("int.zig").getIntSize;
const packInt = @import("int.zig").packInt;
const unpackInt = @import("int.zig").unpackInt;

inline fn assertEnumType(comptime T: type) type {
switch (@typeInfo(T)) {
.@"enum" => return T,
.optional => |opt_info| {
return assertEnumType(opt_info.child);
},
else => @compileError("Expected enum, got " ++ @typeName(T)),
}
}

pub fn getMaxEnumSize(comptime T: type) usize {
const Type = assertEnumType(T);
const tag_type = @typeInfo(Type).@"enum".tag_type;
return 1 + @sizeOf(tag_type);
}

pub fn getEnumSize(comptime T: type, value: T) usize {
if (@typeInfo(T) == .optional) {
if (value) |v| {
return getEnumSize(@typeInfo(T).optional.child, v);
} else {
return 1; // size of null
}
}

const tag_type = @typeInfo(T).@"enum".tag_type;
const int_value = @intFromEnum(value);
return getIntSize(tag_type, int_value);
}

pub fn packEnum(writer: anytype, comptime T: type, value_or_maybe_null: T) !void {
const Type = assertEnumType(T);
const value: Type = try maybePackNull(writer, T, value_or_maybe_null) orelse return;

const tag_type = @typeInfo(Type).@"enum".tag_type;
const int_value = @intFromEnum(value);

try packInt(writer, tag_type, int_value);
}

pub fn unpackEnum(reader: anytype, comptime T: type) !T {
const Type = assertEnumType(T);
const tag_type = @typeInfo(Type).@"enum".tag_type;

// Construct the optional tag type to match T's optionality
const OptionalTagType = if (@typeInfo(T) == .optional) ?tag_type else tag_type;

// Use unpackInt directly with the constructed optional tag type
const int_value = try unpackInt(reader, OptionalTagType);

// Handle the optional case
if (@typeInfo(T) == .optional) {
if (int_value) |value| {
return @enumFromInt(value);
} else {
return null;
}
} else {
return @enumFromInt(int_value);
}
}

test "getMaxEnumSize" {
const PlainEnum = enum { foo, bar };
const U8Enum = enum(u8) { foo = 1, bar = 2 };
const U16Enum = enum(u16) { foo, bar };

try std.testing.expectEqual(2, getMaxEnumSize(PlainEnum)); // u1 + header
try std.testing.expectEqual(2, getMaxEnumSize(U8Enum)); // u8 + header
try std.testing.expectEqual(3, getMaxEnumSize(U16Enum)); // u16 + header
}

test "getEnumSize" {
const U8Enum = enum(u8) { foo = 0, bar = 150 };

try std.testing.expectEqual(1, getEnumSize(U8Enum, .foo)); // fits in positive fixint
try std.testing.expectEqual(2, getEnumSize(U8Enum, .bar)); // requires u8 format
}

test "pack/unpack enum" {
const PlainEnum = enum { foo, bar };
const U8Enum = enum(u8) { foo = 1, bar = 2 };
const U16Enum = enum(u16) { alpha = 1000, beta = 2000 };

// Test plain enum
{
var buffer = std.ArrayList(u8).init(std.testing.allocator);
defer buffer.deinit();

try packEnum(buffer.writer(), PlainEnum, .bar);

var stream = std.io.fixedBufferStream(buffer.items);
const result = try unpackEnum(stream.reader(), PlainEnum);
try std.testing.expectEqual(PlainEnum.bar, result);
}

// Test enum(u8)
{
var buffer = std.ArrayList(u8).init(std.testing.allocator);
defer buffer.deinit();

try packEnum(buffer.writer(), U8Enum, .bar);

var stream = std.io.fixedBufferStream(buffer.items);
const result = try unpackEnum(stream.reader(), U8Enum);
try std.testing.expectEqual(U8Enum.bar, result);
}

// Test enum(u16)
{
var buffer = std.ArrayList(u8).init(std.testing.allocator);
defer buffer.deinit();

try packEnum(buffer.writer(), U16Enum, .alpha);

var stream = std.io.fixedBufferStream(buffer.items);
const result = try unpackEnum(stream.reader(), U16Enum);
try std.testing.expectEqual(U16Enum.alpha, result);
}
}


test "enum edge cases" {
// Test enum with explicit and auto values
const MixedEnum = enum(u8) {
first = 10,
second, // auto-assigned to 11
third = 20,
fourth, // auto-assigned to 21
};

var buffer = std.ArrayList(u8).init(std.testing.allocator);
defer buffer.deinit();

try packEnum(buffer.writer(), MixedEnum, .second);

var stream = std.io.fixedBufferStream(buffer.items);
const result = try unpackEnum(stream.reader(), MixedEnum);
try std.testing.expectEqual(MixedEnum.second, result);
try std.testing.expectEqual(11, @intFromEnum(result));
}

test "optional enum" {
const TestEnum = enum(u8) { foo = 1, bar = 2 };
const OptionalEnum = ?TestEnum;

// Test non-null optional enum
{
var buffer = std.ArrayList(u8).init(std.testing.allocator);
defer buffer.deinit();

const value: OptionalEnum = .bar;
try packEnum(buffer.writer(), OptionalEnum, value);

var stream = std.io.fixedBufferStream(buffer.items);
const result = try unpackEnum(stream.reader(), OptionalEnum);
try std.testing.expectEqual(@as(OptionalEnum, .bar), result);
}

// Test null optional enum
{
var buffer = std.ArrayList(u8).init(std.testing.allocator);
defer buffer.deinit();

const value: OptionalEnum = null;
try packEnum(buffer.writer(), OptionalEnum, value);

var stream = std.io.fixedBufferStream(buffer.items);
const result = try unpackEnum(stream.reader(), OptionalEnum);
try std.testing.expectEqual(@as(OptionalEnum, null), result);
}
}

test "getEnumSize with optional" {
const TestEnum = enum(u8) { foo = 0, bar = 150 };
const OptionalEnum = ?TestEnum;

// Test non-null optional enum size
const value: OptionalEnum = .bar;
try std.testing.expectEqual(2, getEnumSize(OptionalEnum, value)); // requires u8 format

// Test null optional enum size
const null_value: OptionalEnum = null;
try std.testing.expectEqual(1, getEnumSize(OptionalEnum, null_value)); // size of null
}
70 changes: 70 additions & 0 deletions src/msgpack.zig
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ pub const UnionAsMapOptions = @import("union.zig").UnionAsMapOptions;
pub const packUnion = @import("union.zig").packUnion;
pub const unpackUnion = @import("union.zig").unpackUnion;

pub const getEnumSize = @import("enum.zig").getEnumSize;
pub const getMaxEnumSize = @import("enum.zig").getMaxEnumSize;
pub const packEnum = @import("enum.zig").packEnum;
pub const unpackEnum = @import("enum.zig").unpackEnum;

pub const packAny = @import("any.zig").packAny;
pub const unpackAny = @import("any.zig").unpackAny;

Expand Down Expand Up @@ -144,6 +149,10 @@ pub fn Packer(comptime Writer: type) type {
return packUnion(self.writer, @TypeOf(value), value);
}

pub fn writeEnum(self: Self, value: anytype) !void {
return packEnum(self.writer, @TypeOf(value), value);
}

pub fn write(self: Self, value: anytype) !void {
return packAny(self.writer, value);
}
Expand Down Expand Up @@ -232,6 +241,10 @@ pub fn Unpacker(comptime Reader: type) type {
return unpackUnion(self.reader, self.allocator, T);
}

pub fn readEnum(self: Self, comptime T: type) !T {
return unpackEnum(self.reader, T);
}

pub fn read(self: Self, comptime T: type) !T {
return unpackAny(self.reader, self.allocator, T);
}
Expand Down Expand Up @@ -304,3 +317,60 @@ test "encode/decode" {
try std.testing.expectEqualStrings("John", decoded.value.name);
try std.testing.expectEqual(20, decoded.value.age);
}

test "encode/decode enum" {
const Status = enum(u8) { pending = 1, active = 2, inactive = 3 };
const PlainEnum = enum { foo, bar, baz };

// Test enum(u8)
{
var buffer = std.ArrayList(u8).init(std.testing.allocator);
defer buffer.deinit();

try encode(Status.active, buffer.writer());

const decoded = try decodeFromSlice(Status, std.testing.allocator, buffer.items);
defer decoded.deinit();

try std.testing.expectEqual(Status.active, decoded.value);
}

// Test plain enum
{
var buffer = std.ArrayList(u8).init(std.testing.allocator);
defer buffer.deinit();

try encode(PlainEnum.bar, buffer.writer());

const decoded = try decodeFromSlice(PlainEnum, std.testing.allocator, buffer.items);
defer decoded.deinit();

try std.testing.expectEqual(PlainEnum.bar, decoded.value);
}

// Test optional enum with null
{
var buffer = std.ArrayList(u8).init(std.testing.allocator);
defer buffer.deinit();

try encode(@as(?Status, null), buffer.writer());

const decoded = try decodeFromSlice(?Status, std.testing.allocator, buffer.items);
defer decoded.deinit();

try std.testing.expectEqual(@as(?Status, null), decoded.value);
}

// Test optional enum with value
{
var buffer = std.ArrayList(u8).init(std.testing.allocator);
defer buffer.deinit();

try encode(@as(?Status, .pending), buffer.writer());

const decoded = try decodeFromSlice(?Status, std.testing.allocator, buffer.items);
defer decoded.deinit();

try std.testing.expectEqual(@as(?Status, .pending), decoded.value);
}
}
Loading