树链剖分就是把一棵树拆成几条链来处理,便于线段树进行区间操作。
引入
我们知道区间操作通常用线段树。但是如果要对一棵树上的路径或子树进行操作呢?这就需要把树拆分成链来处理,然后用线段树来维护这些链。
通常说的树链剖分实际上是指重链剖分,就是以树的重链为基础来对树进行拆分,拆分成几个重链。
实际上就是对树进行标号,使得那些路径上的点标号连续。这样就便于线段树进行区间操作了。
首先需要明白的几个姿势
- 重(zhòng)儿子:一个非叶子节点的儿子中,子节点最多的那个儿子。
- 重边(不是chóng边):一个非叶子节点和它重儿子之间的连边。
- 重链:相邻重(zhòng)边连接起来形成的一条链叫重链。
- 轻儿子、轻边:就是剩下的儿子和边。
需要注意的是,每一条重链的起点其实是轻儿子或根节点。
对于是轻儿子同时也是叶子节点的节点来说,它自身就是一条长度为 1 的重链。
树链剖分
有了上面的前置姿势,下面就要正式介绍树链剖分啦。
对节点进行标号
需要两个 DFS。
第一次 DFS 统计深度和子树大小,第二次 DFS 对树进行正式的标号。
需要拿一些数组存下这些信息,我们这样规定:
siz
记录子树大小fa
记录每个节点的父亲节点dep
记录每个节点的深度
以上就是第一遍 DFS 所需记录的信息。
所以第一遍 DFS 很简单,只需要这样写。
void dfs1(int u) {
siz[u] = 1, dep[u] = dep[fa[u]] + 1; // 记录子树大小、深度
for (int i = head[u]; ~i; i = edge[i].next) {
int v = edge[i].to;
if (v != fa[u]) {
fa[v] = u; // 记录父亲节点
dfs1(v);
siz[u] += siz[v]; // 统计子树大小
}
}
}
重头戏在第二遍 DFS 上。
我们还需要额外的数组来记录信息。
pos
记录经过第二遍 DFS 后节点的标号top
记录每个节点所在重链上的顶端节点
需要注意的是根节点的 top
值是它本身,记得初始化。
所以第二遍 DFS 这样写
void dfs2(int u) {
int maxn = 0, nxt = -1; // nxt 记录重儿子
pos[u] = ++num; // 记录节点的新标号
for (int i = head[u]; ~i; i = edge[i].next) {
int v = edge[i].to;
if (v != fa[u] && siz[v] > maxn) // 找出一个节点子节点中子树最大的节点,说明它是重儿子
maxn = siz[v], nxt = v;
}
if (nxt == -1) return; // 如果它没有重儿子就退出(说明已经到了叶子节点)
top[nxt] = top[u]; // 同一条重链的顶端节点相同
dfs2(nxt); // 递归优先处理重儿子
for (int i = head[u]; ~i; i = edge[i].next) {
int v = edge[i].to;
if (v != fa[u] && v != nxt) {
top[v] = v; // 对于非重儿子再进行处理,新重链的起点即它本身
dfs2(v); // 找出新的重链
}
}
}
这样进行两遍 DFS 中,我们已经完成了链剖分的全部操作,成功地把树拆分成了几条链。
显然同一条重链上的节点的标号是连续的,这样便于我们进行区间操作。
路径操作
对于同一条链上的两点,它们的标号在线段树上是连续的,我们可以很轻易地进行区间操作来实现路径操作。
但更多的情况是不在同一条链上,所以需要一些神奇的操作。
就那路径上的节点求和为例吧。
我们先取出深度较深的那个点,然后求出这个点到它所在重链的顶端节点的和,然后再把这个点更新为它所在重链顶端节点的父亲节点,这样就又调到了顶端节点上面的那条重链上。再次取出深度较深的那个点,如此往复,直到两点在同一条重链上,这样直接进行区间操作即可。
写成伪(Python)代码就是
def sum(u, v):
ans = 0
while top[u] != top[v]:
if dep[u] > dep[v]:
u, v = v, u # 交换变量,C++ 是 std::swap(u, v);
ans += query_sum(pos[top[v]], pos[v]) # 这是线段树的区间求和操作
v = fa[top[v]]
if pos[u] > pos[v]:
u, v = v, u
ans += query_sum(pos[u], pos[v])
return ans
区间更新也是类似的操作。把求和替换成线段树更新即可。
子树操作
对于每一个子树来说,它们的标号也是连续的(因为是 DFS)。
直接进行区间操作即可。对于以 u 为根的子树,它在区间上的右端点标号为 pos[u] + siz[u] - 1
求 LCA(最近公共祖先)
树链剖分是可以在线求 LCA 的,而且实测比倍增跑得快。
只要两个点在同一条重链上,那么它们的 LCA 一定是深度小的那个节点。
那么如果它们不在一条重链上呢?参考上面路径操作的方法。往上跳直到跳到同一条重链上为止。
伪(Python)代码
def sum(u, v):
while top[u] != top[v]:
if dep[u] > dep[v]:
u, v = v, u
v = fa[top[v]]
if dep[u] > dep[v]:
return v
return u
时间复杂度
(由于本人太蒟蒻了不会证,所以以下内容是网上抄的)
性质一
如果边 $(u,v)$ 为轻边,那么 $size(v)\leq size(u)/2$。
性质二
树中任意两个节点之间的路径中轻边的条数不会超过 $\log _{2}n$ ,重路径的数目不会超过 $\log _{2}n$。
根据以上两点性质以及线段树查询和修改的复杂度 $O(\log_2n)$,可以得知总复杂度为 $O(\log_2^2n)$。
例题
模板题
操作:路径更新与求和,子树更新与求和。
完整代码
#include <bits/stdc++.h>
using namespace std;
const int maxN = 1e5 + 3;
const int INF = 0x3f3f3f3f;
struct Edge {
int next, to;
} edge[maxN << 1];
struct node {
int sum, lazy;
} tree[maxN << 2];
int dep[maxN], siz[maxN], fa[maxN];
int pos[maxN], top[maxN], head[maxN], w[maxN], val[maxN];
int cnt, num, mod, n, m, root;
void add(int from, int to) {
edge[++cnt] = (Edge) {head[from], to};
head[from] = cnt;
}
void dfs1(int u) {
siz[u] = 1, dep[u] = dep[fa[u]] + 1;
for (int i = head[u]; ~i; i = edge[i].next) {
int v = edge[i].to;
if (v != fa[u]) {
fa[v] = u;
dfs1(v);
siz[u] += siz[v];
}
}
}
void dfs2(int u) {
int maxn = 0, nxt = -1; // nxt 记录重儿子
pos[u] = ++num;
for (int i = head[u]; ~i; i = edge[i].next) {
int v = edge[i].to;
if (v != fa[u] && siz[v] > maxn)
maxn = siz[v], nxt = v;
}
if (nxt == -1) return;
top[nxt] = top[u];
dfs2(nxt); // 递归重儿子
for (int i = head[u]; ~i; i = edge[i].next) {
int v = edge[i].to;
if (v != fa[u] && v != nxt) {
top[v] = v;
dfs2(v);
}
}
}
struct SegmentTree {
#define ls (o << 1)
#define rs (o << 1 | 1)
#define mid ((l + r) >> 1)
void push_up(int o) {
tree[o].sum = (tree[ls].sum + tree[rs].sum) % mod;
}
void build(int o, int l, int r) {
if (l == r) {
tree[o].sum = w[l];
return;
}
build(ls, l, mid);
build(rs, mid + 1, r);
push_up(o);
}
void push_down(int o, int l, int r) {
tree[ls].sum = (tree[ls].sum + tree[o].lazy * (mid - l + 1)) % mod;
tree[rs].sum = (tree[rs].sum + tree[o].lazy * (r - mid)) % mod;
tree[ls].lazy = (tree[ls].lazy + tree[o].lazy) % mod;
tree[rs].lazy = (tree[rs].lazy + tree[o].lazy) % mod;
tree[o].lazy = 0;
}
void update(int o, int l, int r, int sl, int sr, int k) {
if (sl > r || sr < l) return;
if (sl <= l && sr >= r) {
tree[o].sum = (tree[o].sum + k * (r - l + 1)) % mod;
tree[o].lazy = (tree[o].lazy + k) % mod;
return;
}
push_down(o, l, r);
if (sl <= mid) update(ls, l, mid, sl, sr, k);
if (sr > mid) update(rs, mid + 1, r, sl, sr, k);
push_up(o);
}
int query(int o, int l, int r, int sl, int sr) {
if (sl <= l && sr >= r) return tree[o].sum % mod;
push_down(o, l, r);
int ans = 0;
if (sl <= mid) ans += query(ls, l, mid, sl, sr);
if (sr > mid) ans += query(rs, mid + 1, r, sl, sr);
return ans % mod;
}
} T;
void update_path(int u, int v, int k) {
while (top[u] != top[v]) {
if (dep[top[u]] > dep[top[v]]) swap(u, v);
T.update(1, 1, num, pos[top[v]], pos[v], k);
v = fa[top[v]];
}
if (pos[u] > pos[v]) swap(u, v);
T.update(1, 1, num, pos[u], pos[v], k);
}
int query_path(int u, int v) {
int ans = 0;
while (top[u] != top[v]) {
if (dep[top[u]] > dep[top[v]]) swap(u, v);
ans = (ans + T.query(1, 1, num, pos[top[v]], pos[v])) % mod;
v = fa[top[v]];
}
if (pos[u] > pos[v]) swap(u, v);
ans = (ans + T.query(1, 1, num, pos[u], pos[v])) % mod;
return ans;
}
void update_tree(int u, int k) {
T.update(1, 1, num, pos[u], pos[u] + siz[u] - 1, k);
}
int query_tree(int u) {
return T.query(1, 1, num, pos[u], pos[u] + siz[u] - 1) % mod;
}
int main() {
memset(head, -1, sizeof(head));
scanf("%d%d%d%d", &n, &m, &root, &mod);
for (int i = 1; i <= n; i++) scanf("%d", &val[i]);
for (int i = 1, a, b; i < n; i++) {
scanf("%d%d", &a, &b);
add(a, b);
add(b, a);
}
dfs1(root); top[root] = root; dfs2(root);
for (int i = 1; i <= n; i++) w[pos[i]] = val[i];
T.build(1, 1, num);
for (int i = 1; i <= m; i++) {
int x, y, k, opt;
scanf("%d%d", &opt, &x);
if (opt == 1) {
scanf("%d%d", &y, &k);
update_path(x, y, k);
} else if (opt == 2) {
scanf("%d", &y);
printf("%d\n", query_path(x, y));
} else if (opt == 3) {
scanf("%d", &k);
update_tree(x, k);
} else printf("%d\n", query_tree(x));
}
return 0;
}
[ZJOI2008]树的统计
操作:单点修改,查询路径最大值与和。
完整代码
#include <bits/stdc++.h>
using namespace std;
const int maxN = 1e5 + 10;
const int INF = 0x7f7f7f7f;
struct Edge {
int next, to;
} edge[maxN << 1];
struct node {
int max, sum;
} st[maxN << 2];
int head[maxN];
int pos[maxN], top[maxN], siz[maxN], dep[maxN], fa[maxN];
int w[maxN];
int n, q, cnt, num = 0;
string opt;
void add(int from, int to) {
edge[++cnt] = (Edge) {head[from], to};
head[from] = cnt;
}
void dfs1(int u) {
siz[u] = 1, dep[u] = dep[fa[u]] + 1;
for (int i = head[u]; ~i; i = edge[i].next) {
int v = edge[i].to;
if (v != fa[u]) {
fa[v] = u;
dfs1(v);
siz[u] += siz[v];
}
}
}
void dfs2(int u) {
int maxn = 0, nxt = -1;
pos[u] = ++num;
for (int i = head[u]; ~i; i = edge[i].next) {
int v = edge[i].to;
if (v != fa[u] && siz[v] > maxn)
maxn = siz[v], nxt = v;
}
if (nxt == -1) return;
top[nxt] = top[u];
dfs2(nxt);
for (int i = head[u]; ~i; i = edge[i].next) {
int v = edge[i].to;
if (v != fa[u] && v != nxt) {
top[v] = v;
dfs2(v);
}
}
}
struct segmentTree {
#define ls (o << 1)
#define rs (o << 1 | 1)
#define mid ((l + r) >> 1)
void build(int o, int l, int r) {
if (l == r) {
st[o].sum = st[o].max = w[l];
return;
}
build(ls, l, mid);
build(rs, mid + 1, r);
st[o].max = max(st[ls].max, st[rs].max);
st[o].sum = st[ls].sum + st[rs].sum;
}
void update(int o, int l, int r, int x, int k) {
if (l == r) {
st[o].sum = st[o].max = k;
return;
}
if (x <= mid) update(ls, l, mid, x, k);
if (x > mid) update(rs, mid + 1, r, x, k);
st[o].max = max(st[ls].max, st[rs].max);
st[o].sum = st[ls].sum + st[rs].sum;
}
int query_sum(int o, int l, int r, int sl, int sr) {
if (sr < l || r < sl) return 0;
if (sl <= l && r <= sr) return st[o].sum;
int sum = 0;
if (sl <= mid) sum += query_sum(ls, l, mid, sl, sr);
if (sr > mid) sum += query_sum(rs, mid + 1, r, sl, sr);
return sum;
}
int query_max(int o, int l, int r, int sl, int sr) {
if (sr < l || r < sl) return 0;
if (sl <= l && r <= sr) return st[o].max;
int maxn = -INF;
if (sl <= mid) maxn = max(maxn, query_max(ls, l, mid, sl, sr));
if (sr > mid) maxn = max(maxn, query_max(rs, mid + 1, r, sl, sr));
return maxn;
}
} T;
int find_sum(int a, int b) {
int ans = 0;
while (top[a] != top[b]) {
if (dep[top[a]] > dep[top[b]]) swap(a, b);
ans += T.query_sum(1, 1, num, pos[top[b]], pos[b]);
b = fa[top[b]];
}
if (pos[a] > pos[b]) swap(a, b);
ans += T.query_sum(1, 1, num, pos[a], pos[b]);
return ans;
}
int find_max(int a, int b) {
int ans = -INF;
while (top[a] != top[b]) {
if (dep[top[a]] > dep[top[b]]) swap(a, b);
ans = max(ans, T.query_max(1, 1, num, pos[top[b]], pos[b]));
b = fa[top[b]];
}
if (pos[a] > pos[b]) swap(a, b);
ans = max(ans, T.query_max(1, 1, num, pos[a], pos[b]));
return ans;
}
int main() {
memset(head, -1, sizeof(head));
scanf("%d", &n);
for (int i = 1, a, b; i < n; i++) {
scanf("%d%d", &a, &b);
add(a, b);
add(b, a);
}
dfs1(1); top[1] = 1; dfs2(1);
for (int i = 1; i <= n; i++) scanf("%d", &w[pos[i]]);
T.build(1, 1, num);
scanf("%d", &q);
for (int i = 1, a, b; i <= q; i++) {
cin >> opt;
scanf("%d%d", &a, &b);
if (opt == "CHANGE") T.update(1, 1, num, pos[a], b);
else if (opt == "QMAX")
printf("%d\n", find_max(a, b));
else if (opt == "QSUM")
printf("%d\n", find_sum(a, b));
}
return 0;
}