Recently I have been learning about segment trees, so I wrote a template for one. I am not entirely sure that it is correct, or whether it can be optimized further. Please go through it, and if you find anything worth commenting on, such as mistakes or possible optimizations, please point it out.
Here is the template:
#include<bits/stdc++.h>
using namespace std;
/// Range-sum segment tree with lazy propagation for "add a value to a range".
///
/// The tree is stored in arrays indexed from 1: node `i` has children `2*i`
/// and `2*i + 1`. Internally the element positions are 0-based; the public
/// convenience wrappers `rsq` and `update_range` take 1-based inclusive
/// positions.
///
/// Invariant used throughout: `lazy[i]` is an addition that still has to be
/// applied to every element covered by node `i`; it is NOT yet reflected in
/// `st[i]`. Every visit of a node first flushes that tag (see `push_down`).
///
/// NOTE(review): sums are kept in `int`, so large inputs/updates can overflow;
/// switch the value type if the problem needs wider sums.
class SegmentTree {
private:
    std::vector<int> st, A;  // st: per-node sums, A: copy of the input array
    std::vector<int> lazy;   // per-node pending additions (see class comment)
    int n;                   // number of elements in A

    /// Applies the pending addition of node `i` (covering [a, b]) to its sum
    /// and hands the tag down to the children, accumulating with whatever
    /// they already have pending.
    void push_down(int i, int a, int b)
    {
        if (lazy[i] == 0)
            return;
        st[i] += (b - a + 1) * lazy[i];
        if (a != b) // leaves have no children to hand the tag to
        {
            lazy[2 * i] += lazy[i];
            lazy[2 * i + 1] += lazy[i];
        }
        lazy[i] = 0;
    }

public:
    /// Builds the tree over a copy of `a`. An empty vector is allowed; every
    /// query on such a tree returns 0 and every update is a no-op.
    SegmentTree(const std::vector<int>& a)
        : st(4 * a.size() + 1, 0),   // 4n nodes are always sufficient
          A(a),
          lazy(4 * a.size() + 1, 0),
          n(static_cast<int>(a.size()))
    {
        build(1, 0, n - 1);
    }

    /// Dumps the raw `st` and `lazy` arrays to stdout for debugging.
    void print() const
    {
        std::cout << "SegmentTree is as follows " << std::endl;
        for (int c : st)
        {
            std::cout << c << " ";
        }
        std::cout << "\nLazy is as follows \n";
        for (int c : lazy)
        {
            std::cout << c << " ";
        }
        std::cout << std::endl;
    }

    /// Recursively fills `st` for node `i`, which covers positions [l, r]
    /// (0-based, inclusive) of `A`. Does nothing for an empty range.
    void build(int i, int l, int r)
    {
        if (l > r)
            return;
        if (l == r)
        {
            st[i] = A[l];
            return;
        }
        const int mid = l + (r - l) / 2;
        build(2 * i, l, mid);
        build(2 * i + 1, mid + 1, r);
        st[i] = st[2 * i] + st[2 * i + 1]; // Modify this as needed
    }

    /// Range-sum query over the 1-based inclusive range [l, r].
    /// Positions outside the array contribute nothing to the sum.
    int rsq(int l, int r)
    {
        return query(1, 0, n - 1, l - 1, r - 1);
    }

    /// Adds `diff` to every element in the 1-based inclusive range [l, r].
    void update_range(int l, int r, int diff)
    {
        update_lazy(1, 0, n - 1, l - 1, r - 1, diff);
    }

    /// Adds `diff` to every element of [l, r] (0-based, inclusive) within the
    /// subtree rooted at node `i`, which covers [a, b].
    void update_lazy(int i, int a, int b, int l, int r, int diff)
    {
        if (a > b) // empty node range (e.g. empty tree): nothing to touch
            return;

        // Flush before the out-of-range test: the parent recomputes its sum
        // from st[2*i] and st[2*i+1], so this node must be up to date even
        // when the update does not reach it.
        push_down(i, a, b);

        if (l > r || l > b || r < a) // Out of range
            return;

        if (l <= a && b <= r) // Node range completely inside [l, r]
        {
            st[i] += (b - a + 1) * diff; // Modify as needed
            if (a != b) // If not leaf, defer the update for the children
            {
                lazy[2 * i] += diff;
                lazy[2 * i + 1] += diff;
            }
            return;
        }

        const int mid = a + (b - a) / 2;
        update_lazy(2 * i, a, mid, l, r, diff);
        update_lazy(2 * i + 1, mid + 1, b, l, r, diff);
        st[i] = st[2 * i] + st[2 * i + 1]; // Modify as needed
    }

    /// Returns the sum of the elements of [l, r] (0-based, inclusive) that lie
    /// within node `i`'s range [a, b]; 0 when the ranges do not intersect.
    int query(int i, int a, int b, int l, int r)
    {
        if (a > b) // empty node range (e.g. empty tree)
            return 0;

        push_down(i, a, b);

        if (r < a || b < l)
            return 0;
        if (l <= a && b <= r)
            return st[i];

        const int mid = a + (b - a) / 2;
        return query(2 * i, a, mid, l, r)
             + query(2 * i + 1, mid + 1, b, l, r); // Modify as needed
    }
};
/// Small demo: builds a tree over 1..8, prints the sum of positions 1..4,
/// adds 2 to each of those positions, and prints the sum again.
int main()
{
    std::vector<int> a(8);
    for (int i = 0; i < 8; i++)
        a[i] = i + 1;

    // Automatic storage instead of a raw `new` that was never deleted.
    SegmentTree sst(a);

    std::cout << sst.rsq(1, 4) << std::endl;
    sst.update_range(1, 4, 2);
    std::cout << sst.rsq(1, 4);
    return 0;
}
Thanks in advance!