Skip to content

Commit dcf6ff4

Browse files
committed
Add SIMD support detection for indexOfDiff
1 parent 6fc5608 commit dcf6ff4

File tree

1 file changed

+63
-38
lines changed

1 file changed

+63
-38
lines changed

lib/std/mem.zig

Lines changed: 63 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -787,8 +787,8 @@ fn eqlBytes(a: []const u8, b: []const u8) bool {
787787
/// Compares two slices and returns the index of the first inequality.
788788
/// Returns null if the slices are equal.
789789
pub fn indexOfDiff(comptime T: type, a: []const T, b: []const T) ?usize {
790-
if (!@inComptime() and @sizeOf(T) != 0 and std.meta.hasUniqueRepresentation(T))
791-
return if (indexOfDiffBytes(sliceAsBytes(a), sliceAsBytes(b))) |index| index / @sizeOf(T) else return null;
790+
if (!@inComptime() and @sizeOf(T) != 0 and std.meta.hasUniqueRepresentation(T) and eqlBytes_allowed)
791+
return if (indexOfDiffBytes(sliceAsBytes(a), sliceAsBytes(b))) |index| index / @sizeOf(T) else null;
792792

793793
const shortest = @min(a.len, b.len);
794794
if (a.ptr == b.ptr) return if (a.len == b.len) null else shortest;
@@ -821,45 +821,66 @@ test indexOfDiff {
821821

822822
/// std.mem.indexOfDiff heavily optimized for slices of bytes.
823823
fn indexOfDiffBytes(a: []const u8, b: []const u8) ?usize {
824+
comptime assert(eqlBytes_allowed);
825+
824826
const shortest = @min(a.len, b.len);
825-
const vec_len = std.simd.suggestVectorLength(u8) orelse 0;
826827
if (a.ptr == b.ptr) return if (a.len == b.len) null else shortest;
827828

828-
if (shortest < @sizeOf(usize)) {
829-
for (0..shortest) |index| if (a[index] != b[index]) return index;
830-
}
831-
// Use SWAR when the slice is small or SIMD is not supported
832-
else if (shortest < 16 or vec_len == 0) {
833-
var index: usize = 0;
834-
while (index + @sizeOf(usize) <= shortest) : (index += @sizeOf(usize)) {
835-
const a_chunk: usize = @bitCast(a[index..][0..@sizeOf(usize)].*);
836-
const b_chunk: usize = @bitCast(b[index..][0..@sizeOf(usize)].*);
837-
const diff = a_chunk ^ b_chunk;
838-
if (diff != 0) {
839-
const offset = @divFloor(if (native_endian == .little) @ctz(diff) else @clz(diff), 8);
840-
return index + offset;
829+
if (shortest < 16) {
830+
if (shortest < @sizeOf(usize)) {
831+
for (0..shortest) |index| if (a[index] != b[index]) return index;
832+
} else {
833+
var index: usize = 0;
834+
while (index + @sizeOf(usize) <= shortest) : (index += @sizeOf(usize)) {
835+
const a_chunk: usize = @bitCast(a[index..][0..@sizeOf(usize)].*);
836+
const b_chunk: usize = @bitCast(b[index..][0..@sizeOf(usize)].*);
837+
const diff = a_chunk ^ b_chunk;
838+
if (diff != 0)
839+
return index + @divFloor(if (native_endian == .little) @ctz(diff) else @clz(diff), 8);
841840
}
842-
}
843-
if (index < shortest) {
844-
const a_chunk: usize = @bitCast(a[shortest - @sizeOf(usize) ..][0..@sizeOf(usize)].*);
845-
const b_chunk: usize = @bitCast(b[shortest - @sizeOf(usize) ..][0..@sizeOf(usize)].*);
846-
const diff = a_chunk ^ b_chunk;
847-
if (diff != 0) {
848-
const offset = @divFloor(if (native_endian == .little) @ctz(diff) else @clz(diff), 8);
849-
return shortest - @sizeOf(usize) + offset;
841+
if (index < shortest) {
842+
const a_chunk: usize = @bitCast(a[shortest - @sizeOf(usize) ..][0..@sizeOf(usize)].*);
843+
const b_chunk: usize = @bitCast(b[shortest - @sizeOf(usize) ..][0..@sizeOf(usize)].*);
844+
const diff = a_chunk ^ b_chunk;
845+
if (diff != 0)
846+
return shortest - @sizeOf(usize) + @divFloor(if (native_endian == .little) @ctz(diff) else @clz(diff), 8);
850847
}
851848
}
849+
return if (a.len == b.len) null else shortest;
852850
}
851+
852+
const Scan = if (std.simd.suggestVectorLength(u8)) |vec_len| struct {
853+
const size = vec_len;
854+
855+
pub inline fn isNotZero(cur_size: comptime_int, mask: @Vector(cur_size, bool)) bool {
856+
return @reduce(.Or, mask);
857+
}
858+
859+
pub inline fn firstTrue(cur_size: comptime_int, mask: @Vector(cur_size, bool)) usize {
860+
return std.simd.firstTrue(mask).?;
861+
}
862+
} else struct {
863+
const size = @sizeOf(usize);
864+
865+
pub inline fn isNotZero(_: comptime_int, mask: usize) bool {
866+
return mask != 0;
867+
}
868+
pub inline fn firstTrue(_: comptime_int, mask: usize) usize {
869+
return @divFloor(if (native_endian == .little) @ctz(mask) else @clz(mask), 8);
870+
}
871+
};
872+
853873
// When the slice is smaller than the max vector length, reselect an appropriate vector length.
854-
else if (shortest < vec_len) {
874+
if (shortest < Scan.size) {
855875
comptime var new_vec_len = 16;
856-
inline while (new_vec_len < vec_len) : (new_vec_len *= 2) {
876+
inline while (new_vec_len < Scan.size) : (new_vec_len *= 2) {
857877
if (new_vec_len < shortest and 2 * new_vec_len >= shortest) {
858878
inline for ([_]usize{ 0, shortest - new_vec_len }) |index| {
859879
const a_chunk: @Vector(new_vec_len, u8) = @bitCast(a[index..][0..new_vec_len].*);
860880
const b_chunk: @Vector(new_vec_len, u8) = @bitCast(b[index..][0..new_vec_len].*);
861881
const diff = a_chunk != b_chunk;
862-
if (@reduce(.Or, diff)) return index + std.simd.firstTrue(diff).?;
882+
if (Scan.isNotZero(new_vec_len, diff))
883+
return index + Scan.firstTrue(new_vec_len, diff);
863884
}
864885
break;
865886
}
@@ -869,25 +890,29 @@ fn indexOfDiffBytes(a: []const u8, b: []const u8) ?usize {
869890
else {
870891
var index: usize = 0;
871892
const unroll_factor = 4;
872-
while (index + vec_len * unroll_factor <= shortest) : (index += vec_len * unroll_factor) {
893+
while (index + Scan.size * unroll_factor <= shortest) : (index += Scan.size * unroll_factor) {
873894
inline for (0..unroll_factor) |i| {
874-
const a_chunk: @Vector(vec_len, u8) = @bitCast(a[index + vec_len * i ..][0..vec_len].*);
875-
const b_chunk: @Vector(vec_len, u8) = @bitCast(b[index + vec_len * i ..][0..vec_len].*);
895+
const a_chunk: @Vector(Scan.size, u8) = @bitCast(a[index + Scan.size * i ..][0..Scan.size].*);
896+
const b_chunk: @Vector(Scan.size, u8) = @bitCast(b[index + Scan.size * i ..][0..Scan.size].*);
876897
const diff = a_chunk != b_chunk;
877-
if (@reduce(.Or, diff)) return index + vec_len * i + std.simd.firstTrue(diff).?;
898+
if (Scan.isNotZero(Scan.size, diff))
899+
return index + Scan.size * i + Scan.firstTrue(Scan.size, diff);
878900
}
879901
}
880-
while (index + vec_len <= shortest) : (index += vec_len) {
881-
const a_chunk: @Vector(vec_len, u8) = @bitCast(a[index..][0..vec_len].*);
882-
const b_chunk: @Vector(vec_len, u8) = @bitCast(b[index..][0..vec_len].*);
902+
while (index + Scan.size <= shortest) : (index += Scan.size) {
903+
const a_chunk: @Vector(Scan.size, u8) = @bitCast(a[index..][0..Scan.size].*);
904+
const b_chunk: @Vector(Scan.size, u8) = @bitCast(b[index..][0..Scan.size].*);
883905
const diff = a_chunk != b_chunk;
884-
if (@reduce(.Or, diff)) return index + std.simd.firstTrue(diff).?;
906+
if (Scan.isNotZero(Scan.size, diff))
907+
return index + Scan.firstTrue(Scan.size, diff);
885908
}
909+
886910
if (index < shortest) {
887-
const a_chunk: @Vector(vec_len, u8) = @bitCast(a[shortest - vec_len ..][0..vec_len].*);
888-
const b_chunk: @Vector(vec_len, u8) = @bitCast(b[shortest - vec_len ..][0..vec_len].*);
911+
const a_chunk: @Vector(Scan.size, u8) = @bitCast(a[shortest - Scan.size ..][0..Scan.size].*);
912+
const b_chunk: @Vector(Scan.size, u8) = @bitCast(b[shortest - Scan.size ..][0..Scan.size].*);
889913
const diff = a_chunk != b_chunk;
890-
if (@reduce(.Or, diff)) return shortest - vec_len + std.simd.firstTrue(diff).?;
914+
if (Scan.isNotZero(Scan.size, diff))
915+
return shortest - Scan.size + Scan.firstTrue(Scan.size, diff);
891916
}
892917
}
893918

0 commit comments

Comments
 (0)