From 3a769d942d172e386f7735bd6b889ebde0dd7389 Mon Sep 17 00:00:00 2001
From: Thomas Jungblut
Date: Fri, 9 May 2025 22:10:15 +0200
Subject: [PATCH] add simd for faster magic number searches

---
 recordio/mmap_reader.go                      |  50 ++++++++
 recordio/simd/magic_number_search.go         |  35 ++++++
 recordio/simd/magic_number_search_generic.go |  35 ++++++
 recordio/simd/magic_number_search_test.go    |  58 +++++++++
 recordio/simd/search.c                       | 121 +++++++++++++++++++
 recordio/simd/search.h                       |  12 ++
 6 files changed, 311 insertions(+)
 create mode 100644 recordio/simd/magic_number_search.go
 create mode 100644 recordio/simd/magic_number_search_generic.go
 create mode 100644 recordio/simd/magic_number_search_test.go
 create mode 100644 recordio/simd/search.c
 create mode 100644 recordio/simd/search.h

diff --git a/recordio/mmap_reader.go b/recordio/mmap_reader.go
--- a/recordio/mmap_reader.go
+++ b/recordio/mmap_reader.go
@@ -5,7 +5,10 @@ import (
 	"bytes"
 	"errors"
 	"fmt"
 	"io"
+	"reflect"
+	"unsafe"
 
+	"github.com/thomasjungblut/go-sstables/recordio/simd"
 	"golang.org/x/exp/mmap"
 
@@ -20,6 +23,9 @@ type MMapReader struct {
 	bufferPool *pool.Pool
 	path       string
 	seekLen    int
+
+	simdAvailable   bool
+	mmapReaderSlice []byte
 }
 
 func (r *MMapReader) Open() error {
@@ -48,6 +54,19 @@ func (r *MMapReader) Open() error {
 	r.header = header
 	r.bufferPool = pool.NewPool(1024, 20)
 	r.open = true
+
+	// The vectorized search needs the raw mapped bytes, which x/exp/mmap keeps
+	// in the unexported "data" field of ReaderAt. Only enable the SIMD path
+	// when that field still has the expected shape, so that a change in the
+	// library degrades to the scalar search instead of panicking.
+	if simd.AVXSupported() {
+		dataField := reflect.ValueOf(r.mmapReader).Elem().FieldByName("data")
+		if dataField.IsValid() && dataField.Type() == reflect.TypeOf([]byte(nil)) {
+			r.mmapReaderSlice = *(*[]byte)(unsafe.Pointer(dataField.UnsafeAddr()))
+			r.simdAvailable = len(r.mmapReaderSlice) > 0
+		}
+	}
+
 	return nil
 }
 
@@ -63,6 +82,10 @@ func (r *MMapReader) SeekNext(offset uint64) (uint64, []byte, error) {
 		return 0, nil, fmt.Errorf("unsupported on files with version lower than v2")
 	}
 
+	if r.simdAvailable {
+		return r.seekNextVectorized(offset)
+	}
+
 	headerBufPooled := r.bufferPool.Get(r.seekLen)
 	defer r.bufferPool.Put(headerBufPooled)
 
@@ -127,6 +150,33 @@ func (r *MMapReader) SeekNext(offset uint64) (uint64, []byte, error) {
 	}
 }
 
+// seekNextVectorized is the SIMD-backed variant of SeekNext. It scans the
+// mapped file for the next magic number at or after offset and returns the
+// offset and payload of the first record that can be read completely.
+// io.EOF is returned when no further magic number is found.
+func (r *MMapReader) seekNextVectorized(offset uint64) (uint64, []byte, error) {
+	i := offset
+	for {
+		ofx := simd.FindMagicNumber(r.mmapReaderSlice, int(i))
+		if ofx < 0 {
+			return 0, nil, io.EOF
+		}
+
+		record, err := r.ReadNextAt(uint64(ofx))
+		if err == nil {
+			return uint64(ofx), record, nil
+		}
+
+		if !errors.Is(err, HeaderChecksumMismatchErr) && !errors.Is(err, MagicNumberMismatchErr) && !errors.Is(err, io.EOF) {
+			return 0, nil, err
+		}
+
+		// try to seek again, the record couldn't be read fully; continue the
+		// search one byte behind the candidate
+		i = uint64(ofx) + 1
+	}
+}
+
 func (r *MMapReader) ReadNextAt(offset uint64) ([]byte, error) {
 	if !r.open || r.closed {
 		return nil, fmt.Errorf("reader at '%s' was either not opened yet or is closed already", r.path)
diff --git a/recordio/simd/magic_number_search.go b/recordio/simd/magic_number_search.go
new file mode 100644
--- /dev/null
+++ b/recordio/simd/magic_number_search.go
@@ -0,0 +1,35 @@
+//go:build cgo && amd64
+
+// Package simd provides a vectorized search for the recordio magic number.
+package simd
+
+/*
+#cgo CFLAGS: -mavx2
+#include "search.h"
+*/
+import "C"
+import "unsafe"
+
+// AVXSupported reports whether the CPU supports the AVX2 instructions that
+// FindFirstMagicNumber and FindMagicNumber are compiled with. Callers must
+// check it before calling either of them.
+func AVXSupported() bool {
+	return C.cpu_supports_avx2() == 1
+}
+
+// FindFirstMagicNumber returns the index of the first magic number in data,
+// or -1 if data does not contain one.
+func FindFirstMagicNumber(data []byte) int {
+	return FindMagicNumber(data, 0)
+}
+
+// FindMagicNumber returns the index of the first magic number in data that
+// starts at or after off. It returns -1 if there is none or if off lies
+// outside of data.
+func FindMagicNumber(data []byte, off int) int {
+	if len(data) < 3 || off < 0 || off >= len(data) {
+		return -1
+	}
+	ptr := (*C.uchar)(unsafe.Pointer(&data[0]))
+	return int(C.find_magic_numbers(ptr, C.size_t(off), C.size_t(len(data))))
+}
diff --git a/recordio/simd/magic_number_search_generic.go b/recordio/simd/magic_number_search_generic.go
new file mode 100644
--- /dev/null
+++ b/recordio/simd/magic_number_search_generic.go
@@ -0,0 +1,35 @@
+//go:build !cgo || !amd64
+
+// Package simd provides a vectorized search for the recordio magic number.
+package simd
+
+import "bytes"
+
+// magicNumber is the byte pattern the AVX2 implementation searches for.
+var magicNumber = []byte{145, 141, 76}
+
+// AVXSupported always reports false, as the AVX2 implementation requires cgo
+// on amd64 and is not part of this build.
+func AVXSupported() bool {
+	return false
+}
+
+// FindFirstMagicNumber returns the index of the first magic number in data,
+// or -1 if data does not contain one.
+func FindFirstMagicNumber(data []byte) int {
+	return FindMagicNumber(data, 0)
+}
+
+// FindMagicNumber returns the index of the first magic number in data that
+// starts at or after off. It returns -1 if there is none or if off lies
+// outside of data.
+func FindMagicNumber(data []byte, off int) int {
+	if len(data) < len(magicNumber) || off < 0 || off >= len(data) {
+		return -1
+	}
+	ix := bytes.Index(data[off:], magicNumber)
+	if ix < 0 {
+		return -1
+	}
+	return off + ix
+}
diff --git a/recordio/simd/magic_number_search_test.go b/recordio/simd/magic_number_search_test.go
new file mode 100644
--- /dev/null
+++ b/recordio/simd/magic_number_search_test.go
@@ -0,0 +1,58 @@
+package simd
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestMagicNumberSearchHappyPath(t *testing.T) {
+	if !AVXSupported() {
+		t.Skip()
+	}
+
+	data := make([]byte, 10000)
+
+	data[10000-300] = 145
+	data[10000-299] = 141
+	data[10000-298] = 76
+
+	data[10000-3] = 145
+	data[10000-2] = 141
+	data[10000-1] = 76
+
+	index := FindFirstMagicNumber(data)
+	require.Equal(t, 10000-300, index)
+	index = FindMagicNumber(data, 0)
+	require.Equal(t, 10000-300, index)
+
+	ix := FindFirstMagicNumber(data[9701:])
+	require.Equal(t, 296, ix)
+	ix = FindMagicNumber(data, 9701)
+	require.Equal(t, 10000-3, ix)
+}
+
+func TestMagicNumberSearchEveryOffset(t *testing.T) {
+	if !AVXSupported() {
+		t.Skip()
+	}
+
+	// covers matches inside the 32 byte blocks, across their boundaries and
+	// in the scalar tail
+	for pos := 0; pos <= 200-3; pos++ {
+		data := make([]byte, 200)
+		copy(data[pos:], []byte{145, 141, 76})
+
+		require.Equal(t, pos, FindFirstMagicNumber(data))
+		require.Equal(t, pos, FindMagicNumber(data, pos))
+		require.Equal(t, -1, FindMagicNumber(data[:pos+2], 0))
+	}
+}
+
+func TestMagicNumberSearchBoundary(t *testing.T) {
+	require.Equal(t, -1, FindFirstMagicNumber([]byte{0, 1}))
+	require.Equal(t, -1, FindMagicNumber([]byte{0, 1}, 0))
+	require.Equal(t, -1, FindMagicNumber([]byte{0, 1, 3, 4}, 3))
+	require.Equal(t, -1, FindMagicNumber([]byte{0, 1, 3, 4}, 4))
+	require.Equal(t, -1, FindMagicNumber([]byte{0, 1, 3, 4}, -1))
+}
diff --git a/recordio/simd/search.c b/recordio/simd/search.c
new file mode 100644
--- /dev/null
+++ b/recordio/simd/search.c
@@ -0,0 +1,121 @@
+//go:build cgo && amd64
+
+#include "search.h"
+#include <cpuid.h>
+#include <immintrin.h>
+#include <stddef.h>
+
+static const unsigned char pattern[] = {145, 141, 76};
+
+// Returns 1 if AVX2 is available, 0 otherwise
+int cpu_supports_avx2(void) {
+    unsigned int eax, ebx, ecx, edx;
+
+    // First, check if CPUID leaf 7 is supported
+    if (__get_cpuid_max(0, 0) < 7)
+        return 0;
+
+    // Call CPUID leaf 7, subleaf 0
+    __cpuid_count(7, 0, eax, ebx, ecx, edx);
+
+    // Bit 5 of EBX in CPUID leaf 7 indicates AVX2 support
+    return (ebx & (1u << 5)) != 0;
+}
+
+// Returns 1 if AVX512 is available, 0 otherwise
+int cpu_supports_avx512(void) {
+    unsigned int eax, ebx, ecx, edx;
+    if (__get_cpuid_max(0, 0) < 7) return 0;
+    __cpuid_count(7, 0, eax, ebx, ecx, edx);
+    return (ebx & (1u << 16)) != 0; // AVX-512F
+}
+
+// Returns the index of the first occurrence of the magic number in data that
+// starts at or after off, or -1 if there is none. len is the total length of
+// data. The result is pointer sized so that offsets beyond 2 GiB survive.
+ptrdiff_t find_magic_numbers(const unsigned char* data, size_t off, size_t len) {
+    if (len < 3) return -1;
+    if (off >= len) return -1;
+
+    size_t i = off;
+    // exclusive upper bound for start positions, the last one is len - 3
+    size_t end = len - 2;
+
+    const __m256i p0 = _mm256_set1_epi8((char)pattern[0]);
+    const __m256i p1 = _mm256_set1_epi8((char)pattern[1]);
+    const __m256i p2 = _mm256_set1_epi8((char)pattern[2]);
+
+    // test 32 start positions per iteration using AVX2, the loads touch the
+    // bytes i .. i + 33 which is in bounds because i + 32 <= len - 2
+    for (; i + 32 <= end; i += 32) {
+        __m256i d0 = _mm256_loadu_si256((const __m256i*)(data + i));
+        __m256i d1 = _mm256_loadu_si256((const __m256i*)(data + i + 1));
+        __m256i d2 = _mm256_loadu_si256((const __m256i*)(data + i + 2));
+
+        __m256i m0 = _mm256_cmpeq_epi8(d0, p0);
+        __m256i m1 = _mm256_cmpeq_epi8(d1, p1);
+        __m256i m2 = _mm256_cmpeq_epi8(d2, p2);
+
+        __m256i mask = _mm256_and_si256(_mm256_and_si256(m0, m1), m2);
+        unsigned int matchmask = (unsigned int)_mm256_movemask_epi8(mask);
+
+        if (matchmask) {
+            // bit n is set when the pattern starts at i + n
+            return (ptrdiff_t)(i + (size_t)__builtin_ctz(matchmask));
+        }
+    }
+
+    // Fallback naive scan for remaining bytes
+    for (; i < end; i++) {
+        if (data[i] == pattern[0] &&
+            data[i+1] == pattern[1] &&
+            data[i+2] == pattern[2]) {
+            return (ptrdiff_t)i;
+        }
+    }
+
+    return -1;
+}
+
+/*
+TODO(thomas): we would need to split the cgo flags and compilation units to match
+
+ptrdiff_t find_magic_numbers_avx512(const unsigned char* data, size_t off, size_t len) {
+    if (len < 3) return -1;
+    if (off >= len) return -1;
+
+    size_t i = off;
+    size_t end = len - 2;
+
+    // process 64 bytes per loop using AVX512
+    for (; i + 64 <= end; i += 64) {
+        __m512i d0 = _mm512_loadu_si512((const void*)(data + i));
+        __m512i d1 = _mm512_loadu_si512((const void*)(data + i + 1));
+        __m512i d2 = _mm512_loadu_si512((const void*)(data + i + 2));
+
+        __m512i p0 = _mm512_set1_epi8(pattern[0]);
+        __m512i p1 = _mm512_set1_epi8(pattern[1]);
+        __m512i p2 = _mm512_set1_epi8(pattern[2]);
+
+        __mmask64 m0 = _mm512_cmpeq_epi8_mask(d0, p0);
+        __mmask64 m1 = _mm512_cmpeq_epi8_mask(d1, p1);
+        __mmask64 m2 = _mm512_cmpeq_epi8_mask(d2, p2);
+
+        __mmask64 m = m0 & m1 & m2;
+        if (m) {
+            return i + __builtin_ctzll(m);
+        }
+    }
+
+    // Fallback naive scan for remaining bytes
+    for (; i < end; i++) {
+        if (data[i] == pattern[0] &&
+            data[i+1] == pattern[1] &&
+            data[i+2] == pattern[2]) {
+            return i;
+        }
+    }
+
+    return -1;
+}
+*/
diff --git a/recordio/simd/search.h b/recordio/simd/search.h
new file mode 100644
--- /dev/null
+++ b/recordio/simd/search.h
@@ -0,0 +1,12 @@
+#ifndef SEARCH_H
+#define SEARCH_H
+
+#include <stddef.h>
+
+int cpu_supports_avx2(void);
+
+int cpu_supports_avx512(void);
+
+ptrdiff_t find_magic_numbers(const unsigned char* data, size_t off, size_t len);
+
+#endif