树链剖分 学习笔记

模板及讲解
树链剖分解决树上的修改问题。
将树剖成一条条链,再用线段树、树状数组等维护

常见题型:
1、点权问题
Q:修改某些点的权进行询问。
解:直接树剖进行线段树/树状数组维护
例题:bzoj1036
2、边权问题
Q:修改某些边的权进行询问。
解:树剖后维护点权,每个点的点权为这个点到他父亲之间边权,询问时删除lca的点权即可
例题:poj2763
3、子树问题
Q:修改结点u为根的子树的点权。
解:由树剖的性质可得,树剖后结点u为根的子树在线段树上的区间是连续的一段,那么记录一个左端点和右端点即可(时间戳思想)
例题:bzoj4034

相关代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#define fo(i, j, k) for (i=(j);i<=(k);i++)
#define fd(i, k, j) for (i=(k);i>=(j);i--)
#define rd(a) scanf("%d", &a)
#define rd2(a, b) scanf("%d%d", &a, &b)
#define rd3(a, b, c) scanf("%d%d%d", &a, &b, &c)
#define ms(i, j) memset(i, j, sizeof i)
#define FN2 "bzoj1036"
using namespace std;
const int MAXN = 30000 + 5;
int dep[MAXN], son[MAXN], fa[MAXN], siz[MAXN]; //深度,重儿子,父亲,子树大小
int p[MAXN], top[MAXN], pre; //在线段树中的位置,所在重链顶部,线段树当前标号
int n, wi[MAXN];
vector<int> G[MAXN];
void dfs1(int u, int f)//第一次dfs记录值
{
int i;
dep[u] = dep[f] + 1, fa[u] = f, siz[u] = 1;
fo (i, 0, G[u].size()-1) {
int v = G[u][i];
if (v!=f) {
dfs1(v, u);
siz[u] += siz[v];
if (son[u]==-1||siz[son[u]]<siz[v]) son[u] = v;
}
}
}
void dfs2(int u, int chain) {//第二次dfs连重儿子成重链
int i;
p[u] = ++pre, top[u] = chain;
if (son[u]!=-1) {
dfs2(son[u], chain);
fo (i, 0, G[u].size()-1) {
int v = G[u][i];
if (v!=son[u]&&v!=fa[u]) dfs2(v, v);
}
}
}
int maxv[MAXN*4], sumv[MAXN*4];
void pushup(int o) {
int lc = o*2, rc = o*2+1;
maxv[o] = max(maxv[lc], maxv[rc]);
sumv[o] = sumv[lc] + sumv[rc];
}
void update(int o, int l, int r, int p, int v) {
int lc = o*2, rc = o*2+1, M = (l+r)/2;
if (l==r) {
sumv[o] = maxv[o] = v; return ;
}
if (p<=M) update(lc, l, M, p, v); else if (M<p) update(rc, M+1, r, p, v);
pushup(o);
}
int getMax(int o, int l, int r, int x, int y) {
int lc = o*2, rc = o*2+1, M = (l+r)/2, ret = -200000000;
if (x<=l&&r<=y) {
return maxv[o];
}
if (x<=M) ret = max(ret, getMax(lc, l, M, x, y));
if (M<y) ret = max(ret, getMax(rc, M+1, r, x, y));
return ret;
}
int getSum(int o, int l, int r, int x, int y) {
int lc = o*2, rc = o*2+1, M = (l+r)/2, ret = 0;
if (x<=l&&r<=y) {
return sumv[o];
}
if (x<=M) ret += getSum(lc, l, M, x, y);
if (M<y) ret += getSum(rc, M+1, r, x, y);
return ret;
}
int findMax(int u, int v)
{
int f1 = top[u], f2 = top[v];
int ret = -200000000;
while (f1!=f2) {
if (dep[f1]<dep[f2]) swap(f1, f2), swap(u, v);
ret = max(ret, getMax(1, 1, n, p[f1], p[u]));
u = fa[f1], f1 = top[u];
}
if (dep[u]<dep[v]) swap(u, v);
return max(ret, getMax(1, 1, n, p[v], p[u]));
}
int findSum(int u, int v)
{
int f1 = top[u], f2 = top[v];
int ret = 0;
while (f1!=f2) {
if (dep[f1]<dep[f2]) swap(f1, f2), swap(u, v);
ret += getSum(1, 1, n, p[f1], p[u]);
u = fa[f1], f1 = top[u];
}
if (dep[u]<dep[v]) swap(u, v);
return ret+getSum(1, 1, n, p[v], p[u]);
}
void init() {
int i; pre = 0;
fo (i, 1, n) dep[i] = fa[i] = siz[i] = p[i] = top[i] = 0, son[i] = -1, G[i].clear();
fo (i, 1, n*4) maxv[i] = -200000000, sumv[i] = 0;
fo (i, 1, n-1) {
int a, b; rd2(a, b);
G[a].push_back(b), G[b].push_back(a);
}
}
void solve() {
int q, i;
dfs1(1, 0);
dfs2(1, 1);
fo (i, 1, n) rd(wi[i]), update(1, 1, n, p[i], wi[i]);
rd(q);
fo (i, 1, q) {
char ch[10]; scanf("%s", ch);
if (ch[0]=='C') {
int u, t; rd2(u, t), update(1, 1, n, p[u], t);
} else if (ch[1]=='M') {
int u, v; rd2(u, v), printf("%d\n", findMax(u, v));
} else if (ch[1]=='S') {
int u, v; rd2(u, v), printf("%d\n", findSum(u, v));
}
}
}
int main() {
#ifndef ONLINE_JUDGE
freopen(FN2".in","r",stdin);freopen(FN2".out","w",stdout);
#endif
while (rd(n)==1) init(), solve();
return 0;
}

------ 本文结束 ------