线段树是一种很实用的数据结构,所以也要学习一下。
线段树是用来处理区间问题的,复杂度能到达 $O(\log n)$。
就以洛谷的模板题 线段树 1 为例吧。
1.将某区间每一个数加上x
2.求出某区间每一个数的和
很显然,这道题是要写一个维护区间和的线段树。
就拿代码说吧。
建树
因为线段树是一棵二叉树,所以如果一个结点的编号为 $i$ 那么其左右子结点编号分别为 $2i$ 和 $2i+1$(用位运算可以如下表示)
#define ls (o << 1)
#define rs (o << 1 | 1)
因为要维护区间和,所以要这么写。
void push_up(long long o) {
sum[o].v = sum[ls].v + sum[rs].v;
}
递归建树。
如果左右区间相同说明它是叶子节点,所以就记录输入数据的信息然后 return
就好啦。
void build(long long o, long long l, long long r) {
sum[o].add = 0;
if (l == r) {
sum[o].v = a[l];
return;
}
build(ls, l, mid);
build(rs, mid + 1, r);
push_up(o);
}
区间修改
我们用 add
记录每个节点每次要更新的值,传递式记录,这样有利于减少复杂度。
注意:add
在建树时就已经初始化好了,其实在这里也可以初始化。
下面用 delta()
函数来进行区间更新操作,因为是对区间进行操作,所以最后要乘上区间长度(即区间元素个数)。push_down()
来维护区间更新,每次更新两个子节点并向下传递。
void delta(long long o, long long l, long long r, long long k) {
sum[o].add = sum[o].add + k;
sum[o].v = sum[o].v + k * (r - l + 1);
}
void push_down(long long o, long long l, long long r) {
delta(ls, l, mid, sum[o].add);
delta(rs, mid + 1, r, sum[o].add);
}
下面的代码中,用 sb, se
表示要修改的区间,l, r
表示节点 o
所存的区间。(emm,参数有些长啊,但强迫症不想用 )typedef
void update(long long o, long long sb, long long se, long long l, long long r, long long k) {
if (sb <= l && se >= r) { // 如果到了区间端点就这么做
sum[o].v += k * (r - l + 1);
sum[o].add += k;
return;
} // 是不是很熟悉?这就是 delta() 的操作
push_down(o, l, r);
if (sb <= mid) update(ls, sb, se, l, mid, k);
if (se > mid) update(rs, sb, se, mid + 1, r, k);
push_up(o); // 回溯
}
区间查询
依然是递归,如果在范围内就返回节点存的值,如果出了范围就返回 0。
long long query(long long o, long long sb, long long se, long long l, long long r) {
if (sb <= l && se >= r) return sum[o].v;
if (sb > r || se < l) return 0;
push_down(o, l, r);
return query(ls, sb, se, l, mid) + query(rs, sb, se, mid + 1, r);
}
最终代码
#include <bits/stdc++.h>
using namespace std;
const int maxN = 1e5 + 3;
struct data {
long long v, add;
} sum[maxN << 2];
long long n, m;
long long a[maxN];
struct Tree {
#define ls (o << 1)
#define rs (o << 1 | 1)
#define mid ((r + l) >> 1)
void push_up(long long o) {
sum[o].v = sum[ls].v + sum[rs].v;
}
void build(long long o, long long l, long long r) {
sum[o].add = 0;
if (l == r) {
sum[o].v = a[l];
return;
}
build(ls, l, mid);
build(rs, mid + 1, r);
push_up(o);
}
void delta(long long o, long long l, long long r, long long k) {
sum[o].add = sum[o].add + k;
sum[o].v = sum[o].v + k * (r - l + 1);
}
void push_down(long long o, long long l, long long r) {
delta(ls, l, mid, sum[o].add);
delta(rs, mid + 1, r, sum[o].add);
}
void update(long long o, long long sb, long long se, long long l, long long r, long long k) {
if (sb <= l && se >= r) {
sum[o].v += k * (r - l + 1);
sum[o].add += k;
return;
}
push_down(o, l, r);
if (sb <= mid) update(ls, sb, se, l, mid, k);
if (se > mid) update(rs, sb, se, mid + 1, r, k);
push_up(o);
}
long long query(long long o, long long sb, long long se, long long l, long long r) {
if (sb <= l && se >= r) return sum[o].v;
if (sb > r || se < l) return 0;
push_down(o, l, r);
return query(ls, sb, se, l, mid) + query(rs, sb, se, mid + 1, r);
}
} tree;
int main() {
scanf("%lld%lld", &n, &m);
for (int i = 1; i <= n; i++) scanf("%lld", a + i);
tree.build(1, 1, n);
while (m--) {
int c;
long long x, y, k;
scanf("%d%lld%lld", &c, &x, &y);
if (c == 1) {
scanf("%lld", &k);
tree.update(1, x, y, 1, n, k);
} else if (c == 2)
printf("%lld\n", tree.query(1, x, y, 1, n));
}
}