树链剖分就是把一棵树拆成几条链来处理,便于线段树进行区间操作。

引入

我们知道区间操作通常用线段树。但是如果要对一棵树上的路径或子树进行操作呢?这就需要把树拆分成链来处理,然后用线段树来维护这些链。

通常说的树链剖分实际上是指重链剖分,就是以树的重链为基础来对树进行拆分,拆分成几个重链。
实际上就是对树进行标号,使得那些路径上的点标号连续。这样就便于线段树进行区间操作了。

首先需要明白的几个姿势

  • 重(zhòng)儿子:一个非叶子节点的儿子中,子节点最多的那个儿子。
  • 重边(不是chóng边):一个非叶子节点和它重儿子之间的连边。
  • 重链:相邻重(zhòng)边连接起来形成的一条链叫重链。
  • 轻儿子、轻边:就是剩下的儿子和边。

需要注意的是,每一条重链的起点其实是轻儿子或根节点。
对于是轻儿子同时也是叶子节点的节点来说,它自身就是一条长度为 1 的重链。

树链剖分

有了上面的前置姿势,下面就要正式介绍树链剖分啦。

对节点进行标号

需要两个 DFS。

第一次 DFS 统计深度和子树大小,第二次 DFS 对树进行正式的标号。

需要拿一些数组存下这些信息,我们这样规定:

  • siz 记录子树大小
  • fa 记录每个节点的父亲节点
  • dep 记录每个节点的深度

以上就是第一遍 DFS 所需记录的信息。

所以第一遍 DFS 很简单,只需要这样写。

void dfs1(int u) {
    siz[u] = 1, dep[u] = dep[fa[u]] + 1; // 记录子树大小、深度
    for (int i = head[u]; ~i; i = edge[i].next) {
        int v = edge[i].to;
        if (v != fa[u]) {
            fa[v] = u; // 记录父亲节点
            dfs1(v);
            siz[u] += siz[v]; // 统计子树大小
        }
    }
}

重头戏在第二遍 DFS 上。

我们还需要额外的数组来记录信息。

  • pos 记录经过第二遍 DFS 后节点的标号
  • top 记录每个节点所在重链上的顶端节点

需要注意的是根节点的 top 值是它本身,记得初始化。

所以第二遍 DFS 这样写

void dfs2(int u) {
    int maxn = 0, nxt = -1; // nxt 记录重儿子
    pos[u] = ++num; // 记录节点的新标号
    for (int i = head[u]; ~i; i = edge[i].next) {
        int v = edge[i].to;
        if (v != fa[u] && siz[v] > maxn) // 找出一个节点子节点中子树最大的节点,说明它是重儿子
            maxn = siz[v], nxt = v;
    }
    if (nxt == -1) return; // 如果它没有重儿子就退出(说明已经到了叶子节点)
    top[nxt] = top[u]; // 同一条重链的顶端节点相同
    dfs2(nxt); // 递归优先处理重儿子
    for (int i = head[u]; ~i; i = edge[i].next) {
        int v = edge[i].to;
        if (v != fa[u] && v != nxt) {
            top[v] = v; // 对于非重儿子再进行处理,新重链的起点即它本身
            dfs2(v); // 找出新的重链
        }
    }
}

这样进行两遍 DFS 中,我们已经完成了链剖分的全部操作,成功地把树拆分成了几条链。

显然同一条重链上的节点的标号是连续的,这样便于我们进行区间操作。

路径操作

对于同一条链上的两点,它们的标号在线段树上是连续的,我们可以很轻易地进行区间操作来实现路径操作。

但更多的情况是不在同一条链上,所以需要一些神奇的操作。

就那路径上的节点求和为例吧。

我们先取出深度较深的那个点,然后求出这个点到它所在重链的顶端节点的和,然后再把这个点更新为它所在重链顶端节点的父亲节点,这样就又调到了顶端节点上面的那条重链上。再次取出深度较深的那个点,如此往复,直到两点在同一条重链上,这样直接进行区间操作即可。

写成伪(Python)代码就是

def sum(u, v):
    ans = 0
    while top[u] != top[v]:
        if dep[u] > dep[v]:
            u, v = v, u    # 交换变量,C++ 是 std::swap(u, v);
        ans += query_sum(pos[top[v]], pos[v]) # 这是线段树的区间求和操作
        v = fa[top[v]]
    if pos[u] > pos[v]:
        u, v = v, u
    ans += query_sum(pos[u], pos[v])
    return ans

区间更新也是类似的操作。把求和替换成线段树更新即可。

子树操作

对于每一个子树来说,它们的标号也是连续的(因为是 DFS)。

直接进行区间操作即可。对于以 u 为根的子树,它在区间上的右端点标号为 pos[u] + siz[u] - 1

求 LCA(最近公共祖先)

树链剖分是可以在线求 LCA 的,而且实测比倍增跑得快。

只要两个点在同一条重链上,那么它们的 LCA 一定是深度小的那个节点。

那么如果它们不在一条重链上呢?参考上面路径操作的方法。往上跳直到跳到同一条重链上为止。

伪(Python)代码

def sum(u, v):
    while top[u] != top[v]:
        if dep[u] > dep[v]:
            u, v = v, u
        v = fa[top[v]]
    if dep[u] > dep[v]:
        return v
    return u

时间复杂度

(由于本人太蒟蒻了不会证,所以以下内容是网上抄的)

性质一

如果边 $(u,v)$ 为轻边,那么 $size(v)\leq size(u)/2$。

性质二

树中任意两个节点之间的路径中轻边的条数不会超过 $\log _{2}n$ ,重路径的数目不会超过 $\log _{2}n$。

根据以上两点性质以及线段树查询和修改的复杂度 $O(\log_2n)$,可以得知总复杂度为 $O(\log_2^2n)$。

例题

模板题

传送门

操作:路径更新与求和,子树更新与求和。

完整代码

#include <bits/stdc++.h>

using namespace std;

const int maxN = 1e5 + 3;
const int INF = 0x3f3f3f3f;

struct Edge {
    int next, to;
} edge[maxN << 1];

struct node {
    int sum, lazy;
} tree[maxN << 2];

int dep[maxN], siz[maxN], fa[maxN];
int pos[maxN], top[maxN], head[maxN], w[maxN], val[maxN];
int cnt, num, mod, n, m, root;

void add(int from, int to) {
    edge[++cnt] = (Edge) {head[from], to};
    head[from] = cnt;
}

void dfs1(int u) {
    siz[u] = 1, dep[u] = dep[fa[u]] + 1;
    for (int i = head[u]; ~i; i = edge[i].next) {
        int v = edge[i].to;
        if (v != fa[u]) {
            fa[v] = u;
            dfs1(v);
            siz[u] += siz[v];
        }
    }
}

void dfs2(int u) {
    int maxn = 0, nxt = -1; // nxt 记录重儿子
    pos[u] = ++num;
    for (int i = head[u]; ~i; i = edge[i].next) {
        int v = edge[i].to;
        if (v != fa[u] && siz[v] > maxn)
            maxn = siz[v], nxt = v;
    }
    if (nxt == -1) return;
    top[nxt] = top[u];
    dfs2(nxt); // 递归重儿子
    for (int i = head[u]; ~i; i = edge[i].next) {
        int v = edge[i].to;
        if (v != fa[u] && v != nxt) {
            top[v] = v;
            dfs2(v);
        }
    }
}

struct SegmentTree {
    #define ls (o << 1)
    #define rs (o << 1 | 1)
    #define mid ((l + r) >> 1)
    void push_up(int o) {
        tree[o].sum = (tree[ls].sum + tree[rs].sum) % mod;
    }
    void build(int o, int l, int r) {
        if (l == r) {
            tree[o].sum = w[l];
            return;
        }
        build(ls, l, mid);
        build(rs, mid + 1, r);
        push_up(o);
    }
    void push_down(int o, int l, int r) {
        tree[ls].sum = (tree[ls].sum + tree[o].lazy * (mid - l + 1)) % mod;
        tree[rs].sum = (tree[rs].sum + tree[o].lazy * (r - mid)) % mod;
        tree[ls].lazy = (tree[ls].lazy + tree[o].lazy) % mod;
        tree[rs].lazy = (tree[rs].lazy + tree[o].lazy) % mod;
        tree[o].lazy = 0;
    }
    void update(int o, int l, int r, int sl, int sr, int k) {
        if (sl > r || sr < l) return;
        if (sl <= l && sr >= r) {
            tree[o].sum = (tree[o].sum + k * (r - l + 1)) % mod;
            tree[o].lazy = (tree[o].lazy + k) % mod;
            return;
        }
        push_down(o, l, r);
        if (sl <= mid) update(ls, l, mid, sl, sr, k);
        if (sr > mid) update(rs, mid + 1, r, sl, sr, k);
        push_up(o);
    }
    int query(int o, int l, int r, int sl, int sr) {
        if (sl <= l && sr >= r) return tree[o].sum % mod;
        push_down(o, l, r);
        int ans = 0;
        if (sl <= mid) ans += query(ls, l, mid, sl, sr);
        if (sr > mid) ans += query(rs, mid + 1, r, sl, sr);
        return ans % mod;
    }
} T;

void update_path(int u, int v, int k) {
    while (top[u] != top[v]) {
        if (dep[top[u]] > dep[top[v]]) swap(u, v);
        T.update(1, 1, num, pos[top[v]], pos[v], k);
        v = fa[top[v]];
    }
    if (pos[u] > pos[v]) swap(u, v);
    T.update(1, 1, num, pos[u], pos[v], k);
}

int query_path(int u, int v) {
    int ans = 0;
    while (top[u] != top[v]) {
        if (dep[top[u]] > dep[top[v]]) swap(u, v);
        ans = (ans + T.query(1, 1, num, pos[top[v]], pos[v])) % mod;
        v = fa[top[v]];
    }
    if (pos[u] > pos[v]) swap(u, v);
    ans = (ans + T.query(1, 1, num, pos[u], pos[v])) % mod;
    return ans;
}

void update_tree(int u, int k) {
    T.update(1, 1, num, pos[u], pos[u] + siz[u] - 1, k);
}

int query_tree(int u) {
    return T.query(1, 1, num, pos[u], pos[u] + siz[u] - 1) % mod;
}

int main() {
    memset(head, -1, sizeof(head));
    scanf("%d%d%d%d", &n, &m, &root, &mod);
    for (int i = 1; i <= n; i++) scanf("%d", &val[i]);
    for (int i = 1, a, b; i < n; i++) {
        scanf("%d%d", &a, &b);
        add(a, b);
        add(b, a);
    }
    dfs1(root); top[root] = root; dfs2(root);
    for (int i = 1; i <= n; i++) w[pos[i]] = val[i];
    T.build(1, 1, num);
    for (int i = 1; i <= m; i++) {
        int x, y, k, opt;
        scanf("%d%d", &opt, &x);
        if (opt == 1) {
            scanf("%d%d", &y, &k);
            update_path(x, y, k);
        } else if (opt == 2) {
            scanf("%d", &y);
            printf("%d\n", query_path(x, y));
        } else if (opt == 3) {
            scanf("%d", &k);
            update_tree(x, k);
        } else printf("%d\n", query_tree(x));
    }
    return 0;
}

[ZJOI2008]树的统计

传送门

操作:单点修改,查询路径最大值与和。

完整代码

#include <bits/stdc++.h>

using namespace std;

const int maxN = 1e5 + 10;
const int INF = 0x7f7f7f7f;

struct Edge {
    int next, to;
} edge[maxN << 1];

struct node {
    int max, sum;
} st[maxN << 2];

int head[maxN];
int pos[maxN], top[maxN], siz[maxN], dep[maxN], fa[maxN];
int w[maxN];
int n, q, cnt, num = 0;
string opt;

void add(int from, int to) {
    edge[++cnt] = (Edge) {head[from], to};
    head[from] = cnt;
}

void dfs1(int u) {
    siz[u] = 1, dep[u] = dep[fa[u]] + 1;
    for (int i = head[u]; ~i; i = edge[i].next) {
        int v = edge[i].to;
        if (v != fa[u]) {
            fa[v] = u;
            dfs1(v);
            siz[u] += siz[v];
        }
    }
}

void dfs2(int u) {
    int maxn = 0, nxt = -1;
    pos[u] = ++num;
    for (int i = head[u]; ~i; i = edge[i].next) {
        int v = edge[i].to;
        if (v != fa[u] && siz[v] > maxn)
            maxn = siz[v], nxt = v;
    }
    if (nxt == -1) return;
    top[nxt] = top[u];
    dfs2(nxt);
    for (int i = head[u]; ~i; i = edge[i].next) {
        int v = edge[i].to;
        if (v != fa[u] && v != nxt) {
            top[v] = v;
            dfs2(v);
        }
    }
}

struct segmentTree {
    #define ls (o << 1)
    #define rs (o << 1 | 1)
    #define mid ((l + r) >> 1)
    void build(int o, int l, int r) {
        if (l == r) {
            st[o].sum = st[o].max = w[l];
            return;
        }
        build(ls, l, mid);
        build(rs, mid + 1, r);
        st[o].max = max(st[ls].max, st[rs].max);
        st[o].sum = st[ls].sum + st[rs].sum;
    }
    void update(int o, int l, int r, int x, int k) {
        if (l == r) {
            st[o].sum = st[o].max = k;
            return;
        }
        if (x <= mid) update(ls, l, mid, x, k);
        if (x > mid) update(rs, mid + 1, r, x, k);
        st[o].max = max(st[ls].max, st[rs].max);
        st[o].sum = st[ls].sum + st[rs].sum;
    }
    int query_sum(int o, int l, int r, int sl, int sr) {
        if (sr < l || r < sl) return 0;
        if (sl <= l && r <= sr) return st[o].sum;
        int sum = 0;
        if (sl <= mid) sum += query_sum(ls, l, mid, sl, sr);
        if (sr > mid) sum += query_sum(rs, mid + 1, r, sl, sr);
        return sum;
    }
    int query_max(int o, int l, int r, int sl, int sr) {
        if (sr < l || r < sl) return 0;
        if (sl <= l && r <= sr) return st[o].max;
        int maxn = -INF;
        if (sl <= mid) maxn = max(maxn, query_max(ls, l, mid, sl, sr));
        if (sr > mid) maxn = max(maxn, query_max(rs, mid + 1, r, sl, sr));
        return maxn;
    }
} T;

int find_sum(int a, int b) {
    int ans = 0;
    while (top[a] != top[b]) {
        if (dep[top[a]] > dep[top[b]]) swap(a, b);
        ans += T.query_sum(1, 1, num, pos[top[b]], pos[b]);
        b = fa[top[b]];
    }
    if (pos[a] > pos[b]) swap(a, b);
    ans += T.query_sum(1, 1, num, pos[a], pos[b]);
    return ans;
}

int find_max(int a, int b) {
    int ans = -INF;
    while (top[a] != top[b]) {
        if (dep[top[a]] > dep[top[b]]) swap(a, b);
        ans = max(ans, T.query_max(1, 1, num, pos[top[b]], pos[b]));
        b = fa[top[b]];
    }
    if (pos[a] > pos[b]) swap(a, b);
    ans = max(ans, T.query_max(1, 1, num, pos[a], pos[b]));
    return ans;
}

int main() {
    memset(head, -1, sizeof(head));
    scanf("%d", &n);
    for (int i = 1, a, b; i < n; i++) {
        scanf("%d%d", &a, &b);
        add(a, b);
        add(b, a);
    }
    dfs1(1); top[1] = 1; dfs2(1);
    for (int i = 1; i <= n; i++) scanf("%d", &w[pos[i]]);
    T.build(1, 1, num);
    scanf("%d", &q);
    for (int i = 1, a, b; i <= q; i++) {
        cin >> opt;
        scanf("%d%d", &a, &b);
        if (opt == "CHANGE") T.update(1, 1, num, pos[a], b);
        else if (opt == "QMAX")
            printf("%d\n", find_max(a, b));
        else if (opt == "QSUM")
            printf("%d\n", find_sum(a, b));
    }
    return 0;
}