Skip to content

Enhancing indexOfDiff efficiency in large input slices #24097

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
125 changes: 121 additions & 4 deletions lib/std/mem.zig
Original file line number Diff line number Diff line change
Expand Up @@ -787,11 +787,14 @@ fn eqlBytes(a: []const u8, b: []const u8) bool {
/// Compares two slices and returns the index of the first inequality.
/// Returns null if the slices are equal.
pub fn indexOfDiff(comptime T: type, a: []const T, b: []const T) ?usize {
if (!std.debug.inValgrind() // https://github.com/ziglang/zig/issues/17717
and backend_supports_vectors and !@inComptime() and @sizeOf(T) != 0 and std.meta.hasUniqueRepresentation(T))
return if (indexOfDiffBytes(sliceAsBytes(a), sliceAsBytes(b))) |index| index / @sizeOf(T) else null;

const shortest = @min(a.len, b.len);
if (a.ptr == b.ptr)
return if (a.len == b.len) null else shortest;
var index: usize = 0;
while (index < shortest) : (index += 1) if (a[index] != b[index]) return index;
if (a.ptr == b.ptr) return if (a.len == b.len) null else shortest;

for (0..shortest) |index| if (a[index] != b[index]) return index;
return if (a.len == b.len) null else shortest;
}

Expand All @@ -801,6 +804,120 @@ test indexOfDiff {
try testing.expectEqual(indexOfDiff(u8, "one", "one two"), 3);
try testing.expectEqual(indexOfDiff(u8, "one twx", "one two"), 6);
try testing.expectEqual(indexOfDiff(u8, "xne", "one"), 0);
try testing.expectEqual(indexOfDiff(u16, &.{ 0x4e00, 0x4e8c, 0x4e09, 0x56db }, &.{ 0x4e00, 0x4e8c, 0x4e09 }), 3);
try testing.expectEqual(indexOfDiff(u16, &.{ 0x96f6, 0x4e8c, 0x4e09, 0x56db }, &.{ 0x4e00, 0x4e8c, 0x4e09, 0x56db }), 0);
try testing.expectEqual(indexOfDiff(f64, &.{ 0x8000000000000000, 0x0000000000000000 }, &.{ 0x0000000000000000, 0x0000000000000000 }), 0);
try testing.expectEqual(indexOfDiff(u64, &.{ 0xaaaaaaaaaaaaaaaa, 0xaaaaaaaaaaaabbbb }, &.{ 0xaaaaaaaaaaaaaaaa, 0xaaaaaaaaaaaacccc }), 1);
comptime {
try testing.expectEqual(indexOfDiff(type, &.{ bool, f32 }, &.{ bool, f32 }), null);
try testing.expectEqual(indexOfDiff(type, &.{ bool, f32 }, &.{ f32, bool }), 0);
try testing.expectEqual(indexOfDiff(type, &.{ bool, f32 }, &.{bool}), 1);
try testing.expectEqual(indexOfDiff(comptime_int, &.{ 1, 2, 3 }, &.{ 1, 2, 3 }), null);
try testing.expectEqual(indexOfDiff(comptime_int, &.{ 1, 2, 3 }, &.{ 1, 2, 4 }), 2);
try testing.expectEqual(indexOfDiff(comptime_int, &.{1}, &.{ 1, 2 }), 1);
}
try testing.expectEqual(indexOfDiff(void, &.{ {}, {} }, &.{ {}, {} }), null);
try testing.expectEqual(indexOfDiff(void, &.{{}}, &.{ {}, {} }), 1);
}

/// std.mem.indexOfDiff heavily optimized for slices of bytes.
fn indexOfDiffBytes(a: []const u8, b: []const u8) ?usize {
const shortest = @min(a.len, b.len);
if (a.ptr == b.ptr) return if (a.len == b.len) null else shortest;

const swar_thr = @sizeOf(usize) * 2;
const max_vec_size = std.simd.suggestVectorLength(u8) orelse 0;
const unroll_factor = 4;
// Context used to generate corresponding scanning strategies (SWAR/SIMD) at compile time
const Ctx = struct {
fn Scan(vec_size: comptime_int) type {
return if (vec_size != 0) struct { // SIMD path
const size = vec_size;
const Chunk = @Vector(size, u8);
const Mask = @Vector(size, bool);
inline fn load(src: []const u8) Chunk {
return @bitCast(src[0..size].*);
}
inline fn toMask(lhs: Chunk, rhs: Chunk) Mask {
return lhs != rhs;
}
inline fn firstTrue(mask: Mask) ?usize {
return if (std.simd.firstTrue(mask)) |offset| @intCast(offset) else null;
}
} else struct { // SWAR path
const size = @sizeOf(usize);
const Chunk = usize;
const Mask = usize;
inline fn load(src: []const u8) Chunk {
return @bitCast(src[0..size].*);
}
inline fn toMask(lhs: Chunk, rhs: Chunk) Mask {
return lhs ^ rhs;
}
inline fn hasDiff(mask: Mask) bool {
return mask != 0;
}
inline fn firstTrue(mask: Mask) ?usize {
// Endian-aware
const offset = if (native_endian == .little) @ctz(mask) else @clz(mask);
return if (offset == @bitSizeOf(Mask)) null else offset / 8;
}
};
}
};
// Samll slices (0, @sizeOf(usize) * 2]
if (shortest <= swar_thr) {
const Scan = Ctx.Scan(0);
// (0, @sizeOf(usize))
if (shortest < Scan.size) {
for (0..shortest) |index| if (a[index] != b[index]) return index;
return if (a.len == b.len) null else shortest;
}
// [@sizeOf(usize), @sizeOf(usize) * 2]
inline for ([_]usize{ 0, shortest - Scan.size }) |index| {
const mask = Scan.toMask(Scan.load(a[index..]), Scan.load(b[index..]));
if (Scan.firstTrue(mask)) |offset| return index + offset;
}
return if (a.len == b.len) null else shortest;
}
// Medium slices (@sizeOf(usize) * 2, max_vec_size)
if (shortest < max_vec_size) {
// Finding the appropriate vector length through doubling method
comptime var cur_vec_size = swar_thr;
inline while (cur_vec_size < max_vec_size) : (cur_vec_size *= 2) {
if (cur_vec_size < shortest and shortest <= cur_vec_size * 2) {
const Scan = Ctx.Scan(cur_vec_size);
inline for ([_]usize{ 0, shortest - Scan.size }) |index| {
const mask = Scan.toMask(Scan.load(a[index..]), Scan.load(b[index..]));
if (Scan.firstTrue(mask)) |offset| return index + offset;
}
return if (a.len == b.len) null else shortest;
}
}
}
// Large slices [max_vec_size, +∞)
const Scan = Ctx.Scan(max_vec_size);

var index: usize = 0;
// Main unrolled loop
while (index + Scan.size * unroll_factor <= shortest) : (index += Scan.size * unroll_factor) {
inline for (0..unroll_factor) |i| {
const mask = Scan.toMask(Scan.load(a[index + Scan.size * i ..]), Scan.load(b[index + Scan.size * i ..]));
if (Scan.firstTrue(mask)) |offset| return index + Scan.size * i + offset;
}
}
// Residual iterations
while (index + Scan.size <= shortest) : (index += Scan.size) {
const mask = Scan.toMask(Scan.load(a[index..]), Scan.load(b[index..]));
if (Scan.firstTrue(mask)) |offset| return index + offset;
}
// Final overlapping check
if (index < shortest) {
const mask = Scan.toMask(Scan.load(a[shortest - Scan.size ..]), Scan.load(b[shortest - Scan.size ..]));
if (Scan.firstTrue(mask)) |offset| return shortest - Scan.size + offset;
}

return if (a.len == b.len) null else shortest;
}

/// Takes a sentinel-terminated pointer and returns a slice preserving pointer attributes.
Expand Down