Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/any.zig
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ const unpackStruct = @import("struct.zig").unpackStruct;
const packUnion = @import("union.zig").packUnion;
const unpackUnion = @import("union.zig").unpackUnion;

const getEnumSize = @import("enum.zig").getEnumSize;
const packEnum = @import("enum.zig").packEnum;
const unpackEnum = @import("enum.zig").unpackEnum;

inline fn isString(comptime T: type) bool {
switch (@typeInfo(T)) {
.pointer => |ptr_info| {
Expand All @@ -56,6 +60,7 @@ pub fn sizeOfPackedAny(comptime T: type, value: T) usize {
.bool => return getBoolSize(),
.int => return getIntSize(T, value),
.float => return getFloatSize(T, value),
.@"enum" => return getEnumSize(T, value),
.pointer => |ptr_info| {
if (ptr_info.size == .Slice) {
if (isString(T)) {
Expand Down Expand Up @@ -105,6 +110,7 @@ pub fn packAny(writer: anytype, value: anytype) !void {
},
.@"struct" => return packStruct(writer, T, value),
.@"union" => return packUnion(writer, T, value),
.@"enum" => return packEnum(writer, T, value),
.optional => {
if (value) |val| {
return packAny(writer, val);
Expand All @@ -125,6 +131,7 @@ pub fn unpackAny(reader: anytype, allocator: std.mem.Allocator, comptime T: type
.float => return unpackFloat(reader, T),
.@"struct" => return unpackStruct(reader, allocator, T),
.@"union" => return unpackUnion(reader, allocator, T),
.@"enum" => return unpackEnum(reader, T),
.pointer => |ptr_info| {
if (ptr_info.size == .slice) {
if (isString(T)) {
Expand Down
130 changes: 130 additions & 0 deletions src/enum.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
const std = @import("std");
const hdrs = @import("headers.zig");

const NonOptional = @import("utils.zig").NonOptional;
const maybePackNull = @import("null.zig").maybePackNull;
const maybeUnpackNull = @import("null.zig").maybeUnpackNull;

const getIntSize = @import("int.zig").getIntSize;
const packInt = @import("int.zig").packInt;
const unpackInt = @import("int.zig").unpackInt;

inline fn assertEnumType(comptime T: type) type {
switch (@typeInfo(T)) {
.@"enum" => return T,
.optional => |opt_info| {
return assertEnumType(opt_info.child);
},
else => @compileError("Expected enum, got " ++ @typeName(T)),
}
}

pub fn getMaxEnumSize(comptime T: type) usize {
const Type = assertEnumType(T);
const tag_type = @typeInfo(Type).@"enum".tag_type;
return 1 + @sizeOf(tag_type);
}

pub fn getEnumSize(comptime T: type, value: T) usize {
const Type = assertEnumType(T);
const tag_type = @typeInfo(Type).@"enum".tag_type;
const int_value = @intFromEnum(value);
return getIntSize(tag_type, int_value);
}

pub fn packEnum(writer: anytype, comptime T: type, value_or_maybe_null: T) !void {
const Type = assertEnumType(T);
const value: Type = try maybePackNull(writer, T, value_or_maybe_null) orelse return;

const tag_type = @typeInfo(Type).@"enum".tag_type;
const int_value = @intFromEnum(value);

try packInt(writer, tag_type, int_value);
}

pub fn unpackEnum(reader: anytype, comptime T: type) !T {
const Type = assertEnumType(T);
const tag_type = @typeInfo(Type).@"enum".tag_type;
const int_value = try unpackInt(reader, tag_type);
return @enumFromInt(int_value);
}

test "getMaxEnumSize" {
const PlainEnum = enum { foo, bar };
const U8Enum = enum(u8) { foo = 1, bar = 2 };
const U16Enum = enum(u16) { foo, bar };

try std.testing.expectEqual(2, getMaxEnumSize(PlainEnum)); // u1 + header
try std.testing.expectEqual(2, getMaxEnumSize(U8Enum)); // u8 + header
try std.testing.expectEqual(3, getMaxEnumSize(U16Enum)); // u16 + header
}

test "getEnumSize" {
const U8Enum = enum(u8) { foo = 0, bar = 150 };

try std.testing.expectEqual(1, getEnumSize(U8Enum, .foo)); // fits in positive fixint
try std.testing.expectEqual(2, getEnumSize(U8Enum, .bar)); // requires u8 format
}

test "pack/unpack enum" {
const PlainEnum = enum { foo, bar };
const U8Enum = enum(u8) { foo = 1, bar = 2 };
const U16Enum = enum(u16) { alpha = 1000, beta = 2000 };

// Test plain enum
{
var buffer = std.ArrayList(u8).init(std.testing.allocator);
defer buffer.deinit();

try packEnum(buffer.writer(), PlainEnum, .bar);

var stream = std.io.fixedBufferStream(buffer.items);
const result = try unpackEnum(stream.reader(), PlainEnum);
try std.testing.expectEqual(PlainEnum.bar, result);
}

// Test enum(u8)
{
var buffer = std.ArrayList(u8).init(std.testing.allocator);
defer buffer.deinit();

try packEnum(buffer.writer(), U8Enum, .bar);

var stream = std.io.fixedBufferStream(buffer.items);
const result = try unpackEnum(stream.reader(), U8Enum);
try std.testing.expectEqual(U8Enum.bar, result);
}

// Test enum(u16)
{
var buffer = std.ArrayList(u8).init(std.testing.allocator);
defer buffer.deinit();

try packEnum(buffer.writer(), U16Enum, .alpha);

var stream = std.io.fixedBufferStream(buffer.items);
const result = try unpackEnum(stream.reader(), U16Enum);
try std.testing.expectEqual(U16Enum.alpha, result);
}
}


test "enum edge cases" {
// Test enum with explicit and auto values
const MixedEnum = enum(u8) {
first = 10,
second, // auto-assigned to 11
third = 20,
fourth, // auto-assigned to 21
};

var buffer = std.ArrayList(u8).init(std.testing.allocator);
defer buffer.deinit();

try packEnum(buffer.writer(), MixedEnum, .second);

var stream = std.io.fixedBufferStream(buffer.items);
const result = try unpackEnum(stream.reader(), MixedEnum);
try std.testing.expectEqual(MixedEnum.second, result);
try std.testing.expectEqual(11, @intFromEnum(result));
}
70 changes: 70 additions & 0 deletions src/msgpack.zig
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ pub const UnionAsMapOptions = @import("union.zig").UnionAsMapOptions;
pub const packUnion = @import("union.zig").packUnion;
pub const unpackUnion = @import("union.zig").unpackUnion;

pub const getEnumSize = @import("enum.zig").getEnumSize;
pub const getMaxEnumSize = @import("enum.zig").getMaxEnumSize;
pub const packEnum = @import("enum.zig").packEnum;
pub const unpackEnum = @import("enum.zig").unpackEnum;

pub const packAny = @import("any.zig").packAny;
pub const unpackAny = @import("any.zig").unpackAny;

Expand Down Expand Up @@ -144,6 +149,10 @@ pub fn Packer(comptime Writer: type) type {
return packUnion(self.writer, @TypeOf(value), value);
}

pub fn writeEnum(self: Self, value: anytype) !void {
return packEnum(self.writer, @TypeOf(value), value);
}

pub fn write(self: Self, value: anytype) !void {
return packAny(self.writer, value);
}
Expand Down Expand Up @@ -232,6 +241,10 @@ pub fn Unpacker(comptime Reader: type) type {
return unpackUnion(self.reader, self.allocator, T);
}

pub fn readEnum(self: Self, comptime T: type) !T {
return unpackEnum(self.reader, T);
}

pub fn read(self: Self, comptime T: type) !T {
return unpackAny(self.reader, self.allocator, T);
}
Expand Down Expand Up @@ -304,3 +317,60 @@ test "encode/decode" {
try std.testing.expectEqualStrings("John", decoded.value.name);
try std.testing.expectEqual(20, decoded.value.age);
}

test "encode/decode enum" {
const Status = enum(u8) { pending = 1, active = 2, inactive = 3 };
const PlainEnum = enum { foo, bar, baz };

// Test enum(u8)
{
var buffer = std.ArrayList(u8).init(std.testing.allocator);
defer buffer.deinit();

try encode(Status.active, buffer.writer());

const decoded = try decodeFromSlice(Status, std.testing.allocator, buffer.items);
defer decoded.deinit();

try std.testing.expectEqual(Status.active, decoded.value);
}

// Test plain enum
{
var buffer = std.ArrayList(u8).init(std.testing.allocator);
defer buffer.deinit();

try encode(PlainEnum.bar, buffer.writer());

const decoded = try decodeFromSlice(PlainEnum, std.testing.allocator, buffer.items);
defer decoded.deinit();

try std.testing.expectEqual(PlainEnum.bar, decoded.value);
}

// Test optional enum with null
{
var buffer = std.ArrayList(u8).init(std.testing.allocator);
defer buffer.deinit();

try encode(@as(?Status, null), buffer.writer());

const decoded = try decodeFromSlice(?Status, std.testing.allocator, buffer.items);
defer decoded.deinit();

try std.testing.expectEqual(@as(?Status, null), decoded.value);
}

// Test optional enum with value
{
var buffer = std.ArrayList(u8).init(std.testing.allocator);
defer buffer.deinit();

try encode(@as(?Status, .pending), buffer.writer());

const decoded = try decodeFromSlice(?Status, std.testing.allocator, buffer.items);
defer decoded.deinit();

try std.testing.expectEqual(@as(?Status, .pending), decoded.value);
}
}