Consider wider algorithms to reduce branching
lemire opened this issue · 1 comment
lemire commented
Something like the following, which uses 64-bit masks and can reduce the number of branches quite a bit:
/*
 * Find the first occurrence of needle[0..k) in s[0..n) using AVX2.
 *
 * Each vector iteration tests 64 candidate start positions: two 32-byte
 * lanes are compared and their movemask results merged into one 64-bit
 * mask, so the outer loop branches once per 64 positions instead of once
 * per 32. A position p is a candidate when s[p] == needle[0] and
 * s[p + k - 1] == needle[k - 1]; each candidate is then confirmed with
 * memcmp on the k - 2 interior bytes.
 *
 * Parameters:
 *   s      - haystack; only bytes s[0..n) are ever read.
 *   n      - haystack length in bytes (0 is allowed).
 *   needle - pattern; bytes needle[0..k) are read.
 *   k      - pattern length; must be > 0.
 *
 * Returns the index of the first match, or n when there is no match
 * (including the case k > n, where the needle cannot fit).
 *
 * The target attribute lets this function be compiled in a translation
 * unit built without -mavx2 (e.g. behind runtime CPU dispatch); the
 * caller is still responsible for only calling it on AVX2-capable CPUs.
 */
__attribute__((target("avx2")))
size_t avx2_strstr_anysize64(const char* s, size_t n, const char* needle, size_t k) {
    assert(k > 0);

    if (k > n) {
        return n; /* needle cannot fit; also covers n == 0 */
    }

    /* Valid start positions are 0 .. n - k inclusive. */
    const size_t positions = n - k + 1;

    const __m256i first = _mm256_set1_epi8(needle[0]);
    const __m256i last = _mm256_set1_epi8(needle[k - 1]);

    size_t i = 0;

    /*
     * Vector loop. The furthest byte loaded is s[i + 32 + 31 + k - 1],
     * i.e. s[i + 63 + k - 1], which is inside s[0..n) exactly when
     * i + 64 <= positions. The same condition guarantees every mask bit
     * corresponds to a start position p <= n - k, so no match can be
     * reported that runs past the end of the haystack.
     */
    for (; i + 64 <= positions; i += 64) {
        const __m256i block_first1 = _mm256_loadu_si256((const __m256i*)(s + i));
        const __m256i block_last1 = _mm256_loadu_si256((const __m256i*)(s + i + k - 1));
        const __m256i block_first2 = _mm256_loadu_si256((const __m256i*)(s + i + 32));
        const __m256i block_last2 = _mm256_loadu_si256((const __m256i*)(s + i + k - 1 + 32));

        const __m256i eq_first1 = _mm256_cmpeq_epi8(first, block_first1);
        const __m256i eq_last1 = _mm256_cmpeq_epi8(last, block_last1);
        const __m256i eq_first2 = _mm256_cmpeq_epi8(first, block_first2);
        const __m256i eq_last2 = _mm256_cmpeq_epi8(last, block_last2);

        const uint32_t mask1 = (uint32_t)_mm256_movemask_epi8(_mm256_and_si256(eq_first1, eq_last1));
        const uint32_t mask2 = (uint32_t)_mm256_movemask_epi8(_mm256_and_si256(eq_first2, eq_last2));
        uint64_t mask = (uint64_t)mask1 | ((uint64_t)mask2 << 32);

        while (mask != 0) {
            const size_t bitpos = (size_t)__builtin_ctzll(mask);
            /* First and last bytes already match. For k < 3 there is no
               interior to compare, and skipping memcmp avoids the k - 2
               size_t underflow when k == 1. */
            if (k < 3 || memcmp(s + i + bitpos + 1, needle + 1, k - 2) == 0) {
                return i + bitpos;
            }
            mask &= mask - 1; /* clear the lowest set bit */
        }
    }

    /* Scalar tail: the remaining (< 64) start positions that the vector
       loop could not cover without reading past s + n. */
    for (; i < positions; i++) {
        if (s[i] == needle[0] && s[i + k - 1] == needle[k - 1] &&
            (k < 3 || memcmp(s + i + 1, needle + 1, k - 2) == 0)) {
            return i;
        }
    }

    return n;
}
WojciechMula commented
Done, thanks