Skip to content

Commit 4759254

Browse files
committed
Neon: Use a mask to locate the characters that need to be escaped instead of iterating through the chunk one byte/result at a time.
1 parent 1d00db9 commit 4759254

File tree

2 files changed

+81
-19
lines changed

2 files changed

+81
-19
lines changed

ext/json/ext/generator/generator.c

Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,13 @@ typedef struct _search_state {
114114
#ifdef ENABLE_SIMD
115115
const char *returned_from;
116116
unsigned char maybe_matches[16];
117+
118+
#ifdef HAVE_SIMD_NEON
119+
uint64_t matches_mask;
120+
const char *chunk_base;
121+
uint8_t has_matches;
122+
#endif /* HAVE_SIMD_NEON */
123+
117124
unsigned long current_match_index;
118125
unsigned long maybe_match_length;
119126
#endif /* ENABLE_SIMD */
@@ -273,15 +280,40 @@ static inline unsigned char search_escape_basic_simd_next_match(search_state *se
273280

274281
#ifdef HAVE_SIMD_NEON
275282

283+
/*
 * Pop the lowest pending match out of search->matches_mask and point
 * search->ptr at the byte it refers to.
 *
 * The byte offset is trailing_zeros(mask) >> 2, i.e. the mask is read as one
 * 4-bit group per byte of the chunk beginning at search->chunk_base.
 *
 * Returns 1 after repositioning search->ptr, storing the reduced mask and
 * calling search_flush(); returns 0 without touching the state when the mask
 * is already empty.
 */
static inline unsigned char neon_mask_next_match(search_state *search) {
    const uint64_t pending = search->matches_mask;
    uint32_t byte_offset;

    if (pending == 0) {
        return 0;
    }

    byte_offset = trailing_zeros(pending) >> 2;

    // This relies on escape_UTF8_char_basic moving search->ptr forward by at most one
    // character per match. Reusing the approach for full escaping would additionally
    // require:
    //   search->chunk_base + byte_offset >= search->ptr
    // With the one-character guarantee, a match that directly follows the previous
    // one satisfies:
    //   search->chunk_base + byte_offset == search->ptr
    search->ptr = search->chunk_base + byte_offset;

    // Drop the bit that was just consumed (clears the lowest set bit).
    search->matches_mask = pending & (pending - 1);

    search_flush(search);
    return 1;
}
302+
303+
// See: https://community.arm.com/arm-community-blogs/b/servers-and-cloud-computing-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
304+
/*
 * Compress a 16-lane byte vector into a 64-bit mask with one nibble per lane.
 *
 * Every 16-bit pair of lanes is shifted right by 4 and narrowed to 8 bits, so
 * the 16 input bytes collapse into 8 bytes (4 bits per original lane). Only the
 * top bit of each nibble is kept, which leaves at most one set bit per lane and
 * lets callers clear a match with `mask &= mask - 1`.
 *
 * NOTE(review): the sampled bit is bit 7 of even lanes but bit 3 of odd lanes,
 * so lanes are expected to be comparison results (0x00 or 0xFF) - confirm at
 * call sites.
 */
static inline uint64_t neon_match_mask(uint8x16_t matches) {
    const uint16x8_t lane_pairs = vreinterpretq_u16_u8(matches);
    const uint8x8_t nibbles = vshrn_n_u16(lane_pairs, 4);
    const uint64_t packed = vget_lane_u64(vreinterpret_u64_u8(nibbles), 0);

    return packed & 0x8888888888888888ull;
}
309+
276310
/*
 * Look every byte of `chunk` up in the two 64-entry halves of
 * simd_state.neon.escape_table_basic and OR the results together.
 *
 * vqtbl4q_u8 yields 0 for indices outside 0..63, so the first lookup only
 * answers for bytes 0x00-0x3F. XOR-ing the chunk with 0x40 remaps bytes
 * 0x40-0x7F onto 0..63 for the second table; bytes >= 0x80 are out of range
 * for both lookups and therefore produce 0.
 */
static inline uint8x16_t neon_lut_update(uint8x16_t chunk) {
    const uint8x16_t high_half_index = veorq_u8(chunk, vdupq_n_u8(0x40));
    const uint8x16_t low_half = vqtbl4q_u8(simd_state.neon.escape_table_basic[0], chunk);
    const uint8x16_t high_half = vqtbl4q_u8(simd_state.neon.escape_table_basic[1], high_half_index);

    return vorrq_u8(low_half, high_half);
}
283316

284-
285317
static inline unsigned char search_escape_basic_neon_advance_lut(search_state *search) {
286318
while (search->ptr+sizeof(uint8x16_t) < search->end) {
287319
uint8x16_t chunk = vld1q_u8((const unsigned char *)search->ptr);
@@ -292,11 +324,10 @@ static inline unsigned char search_escape_basic_neon_advance_lut(search_state *s
292324
continue;
293325
}
294326

295-
vst1q_u8(search->maybe_matches, result);
296-
297-
search->current_match_index = 0;
298-
search->maybe_match_length = sizeof(uint8x16_t);
299-
return search_escape_basic_simd_next_match(search);
327+
search->matches_mask = neon_match_mask(vceqq_u8(result, vdupq_n_u8(9)));
328+
search->has_matches = 1;
329+
search->chunk_base = search->ptr;
330+
return neon_mask_next_match(search);
300331
}
301332

302333
// There are fewer than 16 bytes left.
@@ -396,13 +427,11 @@ static unsigned char search_escape_basic_neon_advance_rules(search_state *search
396427
search->ptr += sizeof(uint8x16_t);
397428
continue;
398429
}
399-
400-
// It doesn't matter the value of each byte in 'maybe_matches' as long as a match is non-zero.
401-
vst1q_u8(search->maybe_matches, needs_escape);
402430

403-
search->current_match_index = 0;
404-
search->maybe_match_length = sizeof(uint8x16_t);
405-
return search_escape_basic_simd_next_match(search);
431+
search->matches_mask = neon_match_mask(needs_escape);
432+
search->has_matches = 1;
433+
search->chunk_base = search->ptr;
434+
return neon_mask_next_match(search);
406435
}
407436

408437
// There are fewer than 16 bytes left.
@@ -439,11 +468,17 @@ static unsigned char search_escape_basic_neon_advance_rules(search_state *search
439468

440469
static inline unsigned char search_escape_basic_neon(search_state *search)
441470
{
442-
if (RB_UNLIKELY(search->returned_from != NULL)) {
443-
search->current_match_index += (search->ptr - search->returned_from);
444-
search->returned_from = NULL;
445-
if (RB_UNLIKELY(search_escape_basic_simd_next_match(search))) {
446-
return 1;
471+
if (RB_UNLIKELY(search->has_matches)) {
472+
// There are more matches if search->matches_mask > 0.
473+
if (search->matches_mask > 0) {
474+
if (RB_LIKELY(neon_mask_next_match(search))) {
475+
return 1;
476+
}
477+
} else {
478+
// neon_mask_next_match will only advance search->ptr up to the last matching character.
479+
// Skip over any characters in the last chunk that occur after the last match.
480+
search->has_matches = 0;
481+
search->ptr = search->chunk_base+sizeof(uint8x16_t);
447482
}
448483
}
449484

@@ -1331,7 +1366,8 @@ static void generate_json_string(FBuffer *buffer, struct generate_json_data *dat
13311366

13321367
#ifdef ENABLE_SIMD
13331368
search.current_match_index = 0;
1334-
search.returned_from = NULL;
1369+
search.matches_mask = 0;
1370+
search.has_matches = 0;
13351371
#endif /* ENABLE_SIMD */
13361372

13371373
switch(rb_enc_str_coderange(obj)) {

ext/json/ext/generator/simd.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,32 @@ typedef enum {
88

99
#ifdef ENABLE_SIMD
1010

11+
/*
 * Detect whether the compiler provides __builtin_ctzll.
 * clang is asked directly via __has_builtin; GCC is assumed to have it from
 * version 4.3 onwards; everything else uses the portable loop below.
 */
#ifdef __clang__
#if __has_builtin(__builtin_ctzll)
#define HAVE_BUILTIN_CTZLL 1
#else
#define HAVE_BUILTIN_CTZLL 0
#endif
#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 3))
#define HAVE_BUILTIN_CTZLL 1
#else
#define HAVE_BUILTIN_CTZLL 0
#endif

/*
 * Count the zero bits below the lowest set bit of `input`.
 *
 * Returns a value in 0..63 for non-zero input and 64 when `input` is 0.
 * The zero case is handled explicitly because __builtin_ctzll(0) is undefined,
 * and so that the builtin and the portable fallback agree on every input.
 */
static inline uint32_t trailing_zeros(uint64_t input) {
#if HAVE_BUILTIN_CTZLL
    return input != 0 ? (uint32_t)__builtin_ctzll(input) : 64;
#else
    uint32_t count = 0;

    if (input == 0) {
        return 64;
    }

    // input is non-zero here, so a set bit is reached within 63 shifts.
    while ((input & 1) == 0) {
        count++;
        input >>= 1;
    }
    return count;
#endif
}
36+
1137
#if defined(__ARM_NEON) || defined(__ARM_NEON__) || defined(__aarch64__) || defined(_M_ARM64)
1238
#include <arm_neon.h>
1339

0 commit comments

Comments
 (0)