crazySegmentTree: Segment Tree implementation with 5x faster queries than bottom-up tree
Разница между en6 и en7, 0 символ(ов) изменены
Hello, Codeforces!↵

A few days ago [user:MohammadParsaElahimanesh,2021-04-06] posted a blog titled [Can we find each Required node in segment tree in O(1)?](https://codeforces.net/blog/entry/89377) Apparently what they meant was to find each node in $\mathcal{O}(ans)$, according to [ecnerwala's explanation](https://codeforces.net/blog/entry/89377?#comment-777885). But I was too dumb to realize that and accidentally invented a parallel node resolution method instead, which speeds up segment tree a lot.↵

A benchmark for you first, with 30 million RMQ on a 32-bit integer array of 17 million elements. It was run in custom test on Codeforces on Apr 6, 2021.↵

- **Classic implementation from cp-algorithms:** 7.765 seconds, or 260 ns per query↵
- **Optimized classic implementation:** (which I was taught) 4.452 seconds, or 150 ns per query (75% faster than classic)↵
- **Bottom-up implementation:** 1.914 seconds, or 64 ns per query (133% faster than optimized)↵
- **Novel parallel implementation:** 0.383 seconds, or 13 ns per query (400% faster than bottom-up, or 2000% faster than classic implementation)↵

[cut]↵

<!-- shut up you stupid parser -->↵

## FAQ↵

**Q: Is it really that fast?** I shamelessly stole someone's solution for [problem:1355C] which uses prefix sums: [submission:112167743]. It runs in 46 ms. Then I replaced prefix sums with classic segment tree in [submission:112168469] which runs in 155 ms. The bottom-up implementation runs in 93 ms: [submission:112168530]. Finally, my novel implementation runs in only 62 ms: [submission:112168574]. Compared to the original prefix sums solution, the bottom-up segment tree uses 47 ms in total, and the parallel implementation uses only 16 ms in total. Thus, even in such a simple problem with only prefix queries the novel implementation is 3x faster than the state of art even in practice!↵

**Q: Why?** Maybe you want your $\mathcal{O}(n \log^2 n)$ solution to pass in a problem with $\mathcal{O}(n \log n)$ model solution. Maybe you want to troll problemsetters. Maybe you want to obfuscate your code so that no one would understand you used a segment tree so that no one hacks you *(just kidding, you'll get FST anyway)*. Choose an excuse for yourself. <font color="#f0f0f0">I want contribution too.</font>↵

**Q: License?** Tough question because we're in CP. So you may use it under MIT license for competitive programming, e.g. on Codeforces, and under GPLv3 otherwise.↵

**Q: Any pitfalls?** Yes, sadly. It requires AVX2 instructions which are supported on Codeforces, but may not be supported on other judges.↵


## How it works↵

![Segment tree](https://www.researchgate.net/profile/Nicolae-Tapus/publication/45924823/figure/fig1/AS:307413849788422@1450304581983/A-1D-segment-tree-with-16-leaves-and-a-canonical-decomposition-of-the-range-3-11.png)↵

In a segment tree, a range query is decomposed into red nodes. Classic segment tree implementations don't find these red nodes directly, but execute recursively on green nodes. Bottom-up segment tree implementation does enumerate these nodes directly, but it also enumerates a few other unused nodes.↵

The parallel implementation is an optimization of bottom-up tree. Probably you all know how bottom-up implementation looks like, but I'll cite the main idea nevertheless to show the difference between bottom-up and parallel implementations:↵

In bottom-up segment tree, we find the node corresponding to the leftmost element of the query, i.e. $x[l]$, and the node corresponding to the rightmost query element, i.e. $x[r]$. If we numerate nodes in a special way, the leftmost element will correspond to node $N+l$ and the rightmost will correspond to node $N+r$. After that, the answer is simply the sum of values of all nodes between $N+l$ and $N+r$. Sadly there are $\mathcal{O}(n)$ of those, but we can do the following optimization:↵

If $N+l$ is the left child of its parent and $N+r$ is the right child of its parent, then instead of summing up all nodes in range $[N+l, N+r]$, we can sum up all nodes in range $[\frac{N+l}{2}, \frac{N+r-1}{2}]$. That is, we replace the two nodes with their two parents. Otherwise, if $N+l$ is the right child of its parent, we do `ans += a[N+l]; l++;`, and if $N+r$ is the left child of its parent, we do `ans += a[N+r]; r--;` Then the condition holds and we can do the replacement.↵

In parallel segment tree, we jump to i-th parent of $N+l$ and $N+r$ for all $i$ *simultaneously*, and check the is-left/right-child conditions in parallel as well. The checks are rather simple, so a few bit operations do the trick. We can perform all bitwise operations using AVX2 on 8 integers at once, which means that the core of the query should run about 8 times faster.↵


## Want code? We have some!↵

This is the benchmark, along with the four segment tree implementations I checked and a prefix sum for comparison.↵

<spoiler summary="Benchmark code">↵
```cpp↵
#include <iostream>↵
#include <random>↵
#include <ctime>↵
#include <cassert>↵
#include <immintrin.h>↵


const int Q = 30000000;↵
const int N = 1 << 24;↵


using T = uint32_t;↵
T a[2 * N];↵
T pref[N];↵


const T identity_element = 0;↵
T reduce(T a, T b) {↵
return a + b;↵
}↵
__attribute__((target("sse4.1"))) __m128i reduce(__m128i a, __m128i b) {↵
return _mm_add_epi32(a, b);↵
}↵
__attribute__((target("avx2"))) __m256i reduce(__m256i a, __m256i b) {↵
return _mm256_add_epi32(a, b);↵
}↵

static_assert(sizeof(T) == 4, "Segment tree elements must be 32-bit");↵


T query_recursive_inner(int v, int vl, int vr, int l, int r) {↵
if(l >= r) {↵
return identity_element;↵
}↵
if(l <= vl && vr <= r) {↵
return a[v];↵
}↵
int vm = (vl + vr) / 2;↵
return reduce(query_recursive_inner(v * 2, vl, vm, l, std::min(r, vm)), query_recursive_inner(v * 2 + 1, vm, vr, std::max(l, vm), r));↵
}↵
T query_recursive_inner(int l, int r) {↵
return query_recursive_inner(1, 0, N, l, r + 1);↵
}↵


T query_recursive_outer(int v, int vl, int vr, int l, int r) {↵
if(vl == l && vr == r) {↵
return a[v];↵
} else {↵
int vm = (vl + vr) / 2;↵
if(r <= vm) {↵
return query_recursive_outer(v * 2, vl, vm, l, r);↵
} else if(l >= vm) {↵
return query_recursive_outer(v * 2 + 1, vm, vr, l, r);↵
} else {↵
return reduce(query_recursive_outer(v * 2, vl, vm, l, vm), query_recursive_outer(v * 2 + 1, vm, vr, vm, r));↵
}↵
}↵
}↵
T query_recursive_outer(int l, int r) {↵
return query_recursive_outer(1, 0, N, l, r + 1);↵
}↵


T query_bottom_up(int l, int r) {↵
l += N;↵
r += N;↵
T ans = identity_element;↵
while(l <= r) {↵
if(l & 1) {↵
ans = reduce(ans, a[l]);↵
l++;↵
}↵
if(!(r & 1)) {↵
ans = reduce(ans, a[r]);↵
r--;↵
}↵
l /= 2;↵
r /= 2;↵
}↵
return ans;↵
}↵


int ffs(unsigned int x) {↵
return sizeof(unsigned int) * 8 - 1 - __builtin_clz(x);↵
}↵


__attribute__((target("avx2"))) T query_parallel(int l, int r) {↵
if(l == r) {↵
return a[l + N];↵
}↵

int mbit = ffs(l ^ r);↵
int reset = ((1 << mbit) - 1);↵
int m = r & ~reset;↵

using vecint = T __attribute__((vector_size(32)));↵
__m256i identity_vec = _mm256_set1_epi32(identity_element);↵
vecint vec_ans = (vecint)identity_vec;↵
__m256i indexes = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);↵

if((l & reset) != 0) {↵
int ll = l - 1 + N;↵
int rr = m - 1 + N;↵

int modbit = 0;↵
int maxmodbit = ffs(ll ^ rr) + 1;↵

vecint ll_vec = (vecint)_mm256_srav_epi32(_mm256_set1_epi32(ll), indexes);↵

#define LOOP(content) if(modbit + 8 <= maxmodbit) { \↵
vec_ans = (vecint)reduce((__m256i)vec_ans, _mm256_i32gather_epi32((int*)a, (__m256i)(((ll_vec & 1) - 1) & (ll_vec | 1)), 4)); \↵
ll_vec >>= 8; \↵
modbit += 8; \↵
content \↵
}↵
LOOP(LOOP(LOOP(LOOP())))↵
#undef LOOP↵

__m256i tmp = _mm256_i32gather_epi32((int*)a, (__m256i)(((ll_vec & 1) - 1) & (ll_vec | 1)), 4);↵
__m256i mask = _mm256_cmpgt_epi32(_mm256_set1_epi32(maxmodbit & 7), indexes);↵
vec_ans = (vecint)reduce((__m256i)vec_ans, _mm256_blendv_epi8(identity_vec, tmp, mask));↵
} else {↵
vec_ans[0] = reduce(vec_ans[0], a[(l + N) >> mbit]);↵
}↵

if((r & reset) != reset) {↵
int ll = m + N;↵
int rr = r + 1 + N;↵

int modbit = 0;↵
int maxmodbit = ffs(ll ^ rr) + 1;↵

vecint rr_vec = (vecint)_mm256_srav_epi32(_mm256_set1_epi32(rr), indexes);↵

#define LOOP(content) if(modbit + 8 <= maxmodbit) { \↵
vec_ans = (vecint)reduce((__m256i)vec_ans, _mm256_i32gather_epi32((int*)a, (__m256i)(~((rr_vec & 1) - 1) & (rr_vec - 1)), 4)); \↵
rr_vec >>= 8; \↵
modbit += 8; \↵
content \↵
}↵
LOOP(LOOP(LOOP(LOOP())))↵
#undef LOOP↵

__m256i tmp = _mm256_i32gather_epi32((int*)a, (__m256i)(~((rr_vec & 1) - 1) & (rr_vec - 1)), 4);↵
__m256i mask = _mm256_cmpgt_epi32(_mm256_set1_epi32(maxmodbit & 7), indexes);↵
vec_ans = (vecint)reduce((__m256i)vec_ans, _mm256_blendv_epi8(identity_vec, tmp, mask));↵
} else {↵
vec_ans[0] = reduce(vec_ans[0], a[(r + N) >> mbit]);↵
}↵

// vec_ans = 7 6 5 4 3 2 1 0↵
__m128i low128 = _mm256_castsi256_si128((__m256i)vec_ans); // 3 2 1 0↵
__m128i high128 = _mm256_extractf128_si256((__m256i)vec_ans, 1); // 7 6 5 4↵
__m128i ans128 = reduce(low128, high128); // 7+3 6+2 5+1 4+0↵
T ans = identity_element;↵
for(int i = 0; i < 4; i++) {↵
ans = reduce(ans, ((T __attribute__((vector_size(16))))ans128)[i]);↵
}↵
return ans;↵
}↵


T query_prefix(int l, int r) {↵
return pref[r] - (l == 0 ? 0 : pref[l - 1]);↵
}↵


int main() {↵
std::pair<int, int>* queries = new std::pair<int, int>[Q];↵
for(int i = 0; i < Q; i++) {↵
int l = rand() % N;↵
int r = rand() % N;↵
if(l > r) {↵
std::swap(l, r);↵
}↵
queries[i] = {l, r};↵
}↵
for(int i = 0; i < N; i++) {↵
a[N + i] = rand();↵
}↵
for(int i = N - 1; i >= 1; i--) {↵
a[i] = reduce(a[i * 2], a[i * 2 + 1]);↵
}↵
a[0] = identity_element;↵

for(int i = 0; i < N; i++) {↵
pref[i] = (i == 0 ? 0 : pref[i - 1]) + a[N + i];↵
}↵

#define CHECK(func) { \↵
auto clock_start = clock(); \↵
T checksum = 0; \↵
for(int i = 0; i < Q; i++) { \↵
checksum += func(queries[i].first, queries[i].second); \↵
} \↵
std::cout << #func << ": " << (double)(clock() - clock_start) / CLOCKS_PER_SEC << " seconds (checksum: " << checksum << ")" << std::endl; \↵
}↵

CHECK(query_recursive_inner)↵
CHECK(query_recursive_outer)↵
CHECK(query_bottom_up)↵
CHECK(query_parallel)↵
CHECK(query_prefix)↵

return 0;↵
}↵
```↵
</spoiler>↵

The core is here:↵

<spoiler summary="Main code">↵
```cpp↵
#include <iostream>↵
#include <random>↵
#include <ctime>↵
#include <cassert>↵
#include <immintrin.h>↵


const int N = 1 << 24;↵


using T = uint32_t;↵
T a[2 * N];↵


const T identity_element = 0;↵
T reduce(T a, T b) {↵
return a + b;↵
}↵
__attribute__((target("sse4.1"))) __m128i reduce(__m128i a, __m128i b) {↵
return _mm_add_epi32(a, b);↵
}↵
__attribute__((target("avx2"))) __m256i reduce(__m256i a, __m256i b) {↵
return _mm256_add_epi32(a, b);↵
}↵

static_assert((N & (N - 1)) == 0, "Segment tree size must be a power of two");↵
static_assert(sizeof(T) == 4, "Segment tree elements must be 32-bit");↵


int ffs(unsigned int x) {↵
return sizeof(unsigned int) * 8 - 1 - __builtin_clz(x);↵
}↵


// Returns sum/min/max/etc. in range [l; r], inclusive. The operation is determined by reduce()↵
__attribute__((target("avx2"))) T query_parallel(int l, int r) {↵
if(l == r) {↵
return a[l + N];↵
}↵

int mbit = ffs(l ^ r);↵
int reset = ((1 << mbit) - 1);↵
int m = r & ~reset;↵

using vecint = T __attribute__((vector_size(32)));↵
__m256i identity_vec = _mm256_set1_epi32(identity_element);↵
vecint vec_ans = (vecint)identity_vec;↵
__m256i indexes = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);↵

if((l & reset) != 0) {↵
int ll = l - 1 + N;↵
int rr = m - 1 + N;↵

int modbit = 0;↵
int maxmodbit = ffs(ll ^ rr) + 1;↵

vecint ll_vec = (vecint)_mm256_srav_epi32(_mm256_set1_epi32(ll), indexes);↵

#define LOOP(content) if(modbit + 8 <= maxmodbit) { \↵
vec_ans = (vecint)reduce((__m256i)vec_ans, _mm256_i32gather_epi32((int*)a, (__m256i)(((ll_vec & 1) - 1) & (ll_vec | 1)), 4)); \↵
ll_vec >>= 8; \↵
modbit += 8; \↵
content \↵
}↵
LOOP(LOOP(LOOP(LOOP())))↵
#undef LOOP↵

__m256i tmp = _mm256_i32gather_epi32((int*)a, (__m256i)(((ll_vec & 1) - 1) & (ll_vec | 1)), 4);↵
__m256i mask = _mm256_cmpgt_epi32(_mm256_set1_epi32(maxmodbit & 7), indexes);↵
vec_ans = (vecint)reduce((__m256i)vec_ans, _mm256_blendv_epi8(identity_vec, tmp, mask));↵
} else {↵
vec_ans[0] = reduce(vec_ans[0], a[(l + N) >> mbit]);↵
}↵

if((r & reset) != reset) {↵
int ll = m + N;↵
int rr = r + 1 + N;↵

int modbit = 0;↵
int maxmodbit = ffs(ll ^ rr) + 1;↵

vecint rr_vec = (vecint)_mm256_srav_epi32(_mm256_set1_epi32(rr), indexes);↵

#define LOOP(content) if(modbit + 8 <= maxmodbit) { \↵
vec_ans = (vecint)reduce((__m256i)vec_ans, _mm256_i32gather_epi32((int*)a, (__m256i)(~((rr_vec & 1) - 1) & (rr_vec - 1)), 4)); \↵
rr_vec >>= 8; \↵
modbit += 8; \↵
content \↵
}↵
LOOP(LOOP(LOOP(LOOP())))↵
#undef LOOP↵

__m256i tmp = _mm256_i32gather_epi32((int*)a, (__m256i)(~((rr_vec & 1) - 1) & (rr_vec - 1)), 4);↵
__m256i mask = _mm256_cmpgt_epi32(_mm256_set1_epi32(maxmodbit & 7), indexes);↵
vec_ans = (vecint)reduce((__m256i)vec_ans, _mm256_blendv_epi8(identity_vec, tmp, mask));↵
} else {↵
vec_ans[0] = reduce(vec_ans[0], a[(r + N) >> mbit]);↵
}↵

// vec_ans = 7 6 5 4 3 2 1 0↵
__m128i low128 = _mm256_castsi256_si128((__m256i)vec_ans); // 3 2 1 0↵
__m128i high128 = _mm256_extractf128_si256((__m256i)vec_ans, 1); // 7 6 5 4↵
__m128i ans128 = reduce(low128, high128); // 7+3 6+2 5+1 4+0↵
T ans = identity_element;↵
for(int i = 0; i < 4; i++) {↵
ans = reduce(ans, ((T __attribute__((vector_size(16))))ans128)[i]);↵
}↵
return ans;↵
}↵


int main() {↵
// ...fill array from a[N] to a[2*N-1]...↵

for(int i = N - 1; i >= 1; i--) {↵
a[i] = reduce(a[i * 2], a[i * 2 + 1]);↵
}↵
a[0] = identity_element;↵

// ...your code here...↵

return 0;↵
}↵
```↵
</spoiler>↵

The following line the count of elements in segment tree. It must be a power of two, so instead of using `1e6` use `1 << 20`:↵

```cpp↵
const int N = 1 << 24;↵
```↵

The following line sets the type of the elements. It *must* be a 32-bit integer, either signed or unsigned, at the moment.↵

```cpp↵
using T = uint32_t;↵
```↵

The following line sets the identity element. It's 0 for sum, -inf for max, inf for min. If you use unsigned integers, I'd recommend you to use 0 for max and `(uint32_t)-1` for min.↵

```cpp↵
const T identity_element = 0;↵
```↵

The following function defines the operation itself: sum, min, max, etc.↵

```cpp↵
T reduce(T a, T b) {↵
return a + b;↵
}↵
```↵

Then the following two functions are like `reduce(T, T)` but vectorized: for 128-bit registers and 256-bit registers. There are builtins for add: `_mm[256]_add_epi32`, min (signed): `_mm[256]_min_epi32`, max (signed): `_mm[256]_max_epi32`, min (unsigned): `_mm[256]_min_epu32`, max (unsigned): `_mm[256]_max_epu32`. You can check [Intel Intrinsics Guide](https://software.intel.com/sites/landingpage/IntrinsicsGuide/) if you're not sure.↵

```cpp↵
__attribute__((target("sse4.1"))) __m128i reduce(__m128i a, __m128i b) {↵
return _mm_add_epi32(a, b);↵
}↵
__attribute__((target("avx2"))) __m256i reduce(__m256i a, __m256i b) {↵
return _mm256_add_epi32(a, b);↵
}↵
```↵

Finally, these lines in `main()` are something you should *not* touch. They build the segment tree. Make sure to fill the array from `a[N]` to `a[N*2-1]` before building it.↵

```cpp↵
for(int i = N - 1; i >= 1; i--) {↵
a[i] = reduce(a[i * 2], a[i * 2 + 1]);↵
}↵
a[0] = identity_element;↵
```↵


## Further work↵

Implement point update queries in a similar way. This should be very fast for segment-tree-on-sum with point += queies, segment tree on minimum with point min= and alike.↵

Unfortunately BIT/fenwick tree cannot be optimized this way, it turns out 1.5x slower.↵

Contributions are welcome.↵

История

 
 
 
 
Правки
 
 
  Rev. Язык Кто Когда Δ Комментарий
en9 Английский purplesyringa 2021-04-06 21:47:16 0 (published)
en8 Английский purplesyringa 2021-04-06 21:47:06 19 (saved to drafts)
en7 Английский purplesyringa 2021-04-06 20:47:42 0 (published)
en6 Английский purplesyringa 2021-04-06 20:47:11 14
en5 Английский purplesyringa 2021-04-06 20:46:15 5601
en4 Английский purplesyringa 2021-04-06 20:32:33 7243
en3 Английский purplesyringa 2021-04-06 20:19:44 753
en2 Английский purplesyringa 2021-04-06 19:41:23 42
en1 Английский purplesyringa 2021-04-06 19:40:26 2533 Initial revision (saved to drafts)