Skip to content

Commit ed77dfe

Browse files
committed
Fix and refactor indexOfDiff
1 parent f0f0733 commit ed77dfe

File tree

1 file changed

+14
-15
lines changed

1 file changed

+14
-15
lines changed

lib/std/mem.zig

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -787,7 +787,8 @@ fn eqlBytes(a: []const u8, b: []const u8) bool {
787787
/// Compares two slices and returns the index of the first inequality.
788788
/// Returns null if the slices are equal.
789789
pub fn indexOfDiff(comptime T: type, a: []const T, b: []const T) ?usize {
790-
if (!@inComptime() and @sizeOf(T) != 0 and std.meta.hasUniqueRepresentation(T))
790+
if (!std.debug.inValgrind() // https://github.com/ziglang/zig/issues/17717
791+
and backend_supports_vectors and !@inComptime() and @sizeOf(T) != 0 and std.meta.hasUniqueRepresentation(T))
791792
return if (indexOfDiffBytes(sliceAsBytes(a), sliceAsBytes(b))) |index| index / @sizeOf(T) else null;
792793

793794
const shortest = @min(a.len, b.len);
@@ -833,18 +834,15 @@ fn indexOfDiffBytes(a: []const u8, b: []const u8) ?usize {
833834
return if (vec_size != 0) struct { // SIMD path
834835
const size = vec_size;
835836
const Chunk = @Vector(size, u8);
836-
const Mask = @Type(.{ .int = .{ .bits = size, .signedness = .unsigned } });
837+
const Mask = @Vector(size, bool);
837838
inline fn load(src: []const u8) Chunk {
838839
return @bitCast(src[0..size].*);
839840
}
840841
inline fn toMask(lhs: Chunk, rhs: Chunk) Mask {
841-
return @bitCast(lhs != rhs);
842+
return lhs != rhs;
842843
}
843-
inline fn hasDiff(mask: Mask) bool {
844-
return mask != 0;
845-
}
846-
inline fn firstDiff(mask: Mask) usize {
847-
return @ctz(mask);
844+
inline fn firstTrue(mask: Mask) ?usize {
845+
return if (std.simd.firstTrue(mask)) |offset| @intCast(offset) else null;
848846
}
849847
} else struct { // SWAR path
850848
const size = @sizeOf(usize);
@@ -859,9 +857,10 @@ fn indexOfDiffBytes(a: []const u8, b: []const u8) ?usize {
859857
inline fn hasDiff(mask: Mask) bool {
860858
return mask != 0;
861859
}
862-
inline fn firstDiff(mask: Mask) usize {
860+
inline fn firstTrue(mask: Mask) ?usize {
863861
// Endian-aware
864-
return (if (native_endian == .little) @ctz(mask) else @clz(mask)) / 8;
862+
const offset = if (native_endian == .little) @ctz(mask) else @clz(mask);
863+
return if (offset == @bitSizeOf(Mask)) null else offset / 8;
865864
}
866865
};
867866
}
@@ -877,7 +876,7 @@ fn indexOfDiffBytes(a: []const u8, b: []const u8) ?usize {
877876
// [@sizeOf(usize), @sizeOf(usize) * 2]
878877
inline for ([_]usize{ 0, shortest - Scan.size }) |index| {
879878
const mask = Scan.toMask(Scan.load(a[index..]), Scan.load(b[index..]));
880-
if (Scan.hasDiff(mask)) return index + Scan.firstDiff(mask);
879+
if (Scan.firstTrue(mask)) |offset| return index + offset;
881880
}
882881
return if (a.len == b.len) null else shortest;
883882
}
@@ -890,7 +889,7 @@ fn indexOfDiffBytes(a: []const u8, b: []const u8) ?usize {
890889
const Scan = Ctx.Scan(cur_vec_size);
891890
inline for ([_]usize{ 0, shortest - Scan.size }) |index| {
892891
const mask = Scan.toMask(Scan.load(a[index..]), Scan.load(b[index..]));
893-
if (Scan.hasDiff(mask)) return index + Scan.firstDiff(mask);
892+
if (Scan.firstTrue(mask)) |offset| return index + offset;
894893
}
895894
return if (a.len == b.len) null else shortest;
896895
}
@@ -904,18 +903,18 @@ fn indexOfDiffBytes(a: []const u8, b: []const u8) ?usize {
904903
while (index + Scan.size * unroll_factor <= shortest) : (index += Scan.size * unroll_factor) {
905904
inline for (0..unroll_factor) |i| {
906905
const mask = Scan.toMask(Scan.load(a[index + Scan.size * i ..]), Scan.load(b[index + Scan.size * i ..]));
907-
if (Scan.hasDiff(mask)) return index + Scan.size * i + Scan.firstDiff(mask);
906+
if (Scan.firstTrue(mask)) |offset| return index + Scan.size * i + offset;
908907
}
909908
}
910909
// Residual iterations
911910
while (index + Scan.size <= shortest) : (index += Scan.size) {
912911
const mask = Scan.toMask(Scan.load(a[index..]), Scan.load(b[index..]));
913-
if (Scan.hasDiff(mask)) return index + Scan.firstDiff(mask);
912+
if (Scan.firstTrue(mask)) |offset| return index + offset;
914913
}
915914
// Final overlapping check
916915
if (index < shortest) {
917916
const mask = Scan.toMask(Scan.load(a[shortest - Scan.size ..]), Scan.load(b[shortest - Scan.size ..]));
918-
if (Scan.hasDiff(mask)) return shortest - Scan.size + Scan.firstDiff(mask);
917+
if (Scan.firstTrue(mask)) |offset| return shortest - Scan.size + offset;
919918
}
920919

921920
return if (a.len == b.len) null else shortest;

0 commit comments

Comments
 (0)