yaoxi-std 的博客

$\text{开}\mathop{\text{卷}}\limits^{ju\check{a}n}\text{有益}$

0%

P3690 【模版】动态树(Link Cut Tree)

题面

题目链接

解法

WC实在是太无聊子

今天是冬令营的第一天,我们来学习一下LCT吧(雾

引入&定义

Link Cut Tree,顾名思义,就是可以动态维护树上路径权值和、树上连边与断边操作的一种数据结构。

说到维护树上路径权值和,首先想到的应该就是树链剖分了,一般我们会选择用$O(n \log n)$的重链剖分。但是在动态树上,重链随时会改变,使用重链剖分就不行了。

这种情况下,我们不得不采用更为暴力的实链剖分。不考虑各儿子的子树大小,直接随便取一个儿子与父亲连接成一条链。于是树的形态变化就不会影响复杂度。不幸的是,这样的时间复杂度是$O(n^2)$的。

那么有没有什么数据结构可以很好地维护这样的实链呢?考虑到一条链上的节点深度都是递增的,不妨用Splay来维护每条链,而Splay的中序遍历就是整条链从头到尾的遍历。

如此我们可以将一棵树(森林)转化为若干条链,每条链用Splay来维护,我们称这样形成的树叫做辅助树(森林)。如下图中的例子(原树):

原树

其中实线表示的是随机分配的实边。

它的辅助树形态如下:

辅助树

原树中的实链都在同一棵Splay中,而原树中的虚链只有儿子指向父亲,没有父亲指向儿子。

注意:辅助树的指向不等于原树中的指向,辅助树的根也不等于原树的根!

数组&函数

首先定义要用的数组:

  • fa[x] $x$在辅助树上的父节点
  • ch[x][0/1] $x$在Splay中的左右儿子
  • val[x] $x$的权值
  • sum[x] $x$所在Splay的子树中的权值和
  • tag[x] $x$节点的翻转标记(后面会提到)

需要的函数如下:

  • pushup(x) 更新节点权值,不多说了
  • pushdown(x) 下传翻转标记
  • pushall(x) 将$x$在Splay中到根节点的路径全部下传标记
  • connect(x, f, w) Splay操作:连接$x$为$f$的$w$儿子($w=-1$表示虚边上的儿子)
  • get(x) 获取$x$是哪种儿子
  • rotate(x) Splay操作:旋转
  • splay(x) Splay操作:旋转到根
  • access(x) LCT的核心:将$x$与根放到同一条实链中
  • makeroot(x) 将$x$变成原树的根
  • split(x, y) 将$x$到$y$的路径放到同一条实链中
  • findroot(x) 找到$x$所在实链的链头
  • link(x, y) 连接$x$到$y$
  • cut(x, y) 断开$x$到$y$

代码实现

pushup(x)
1
2
3
4
5
// 更新sum
void pushup(int x) {
// 视情况而定,本题为异或和
sum[x] = sum[ch[x][0]] ^ sum[ch[x][1]] ^ val[x];
}
pushdown(x)
1
2
3
4
5
6
7
8
9
10
11
// 标记下传
void pushdown(int x) {
if (tag[x]) {
swap(ch[x][0], ch[x][1]);
if (ch[x][0])
tag[ch[x][0]] ^= 1;
if (ch[x][1])
tag[ch[x][1]] ^= 1;
tag[x] = 0;
}
}
pushall(x)
1
2
3
4
5
6
// 递归实现即可
void pushall(int x) {
if (~get(x))
pushall(fa[x]);
pushdown(x);
}
connect(x, f, w)
1
2
3
4
// 压到一行
void connect(int x, int f, int w) {
fa[x] = f, (~w) && (ch[f][w] = x);
}
get(x)
1
2
3
4
5
6
7
8
// 这个压不了
int get(int x) {
if (ch[fa[x]][0] == x)
return 0;
if (ch[fa[x]][1] == x)
return 1;
return -1;
}
rotate(x)

和普通的splay不一样了,要判断虚边

1
2
3
4
5
6
7
8
9
void rotate(int x) {
int y = fa[x], z = fa[y], w = get(x);
if (~w)
connect(ch[x][w ^ 1], y, w);
connect(x, z, get(y));
connect(y, x, w ^ 1);
pushup(y);
pushup(x);
}
splay(x)

第一行要先将上方的标记都下传

1
2
3
4
5
6
void splay(int x) {
pushall(x);
for (int f = fa[x]; f = fa[x], ~get(x); rotate(x))
if (~get(f))
rotate(get(f) == get(x) ? f : x);
}
access(x)

辅助树

考虑将上图的节点$9$与根节点$1$打通。

因此要先将$9$转到当前splay的根(如图)。

splay后的辅助树

而根据辅助树上链的性质,$9$、$13$和$10$不应该在一条链上,于是将$9$与$13$断开成虚链(从父亲中移除指向即可),同时将下方与$x$连接的另一棵splay连接(ch[x][1]:=pre)。

以此类推,一直往上重复此操作直到根。

1
2
3
4
5
6
7
8
9
10
void access(int x) {
int pre = 0;
while (x) {
splay(x);
ch[x][1] = pre;
pushup(x);
pre = x;
x = fa[x];
}
}
makeroot(x)

让$x$成为新的原树的根。将$x$转到根节点的splay中,然后让$x$成为根节点。

但是这样一来,$x$所在的splay树就乱掉了,因为无法再保证深度递增。同时发现由于在access(x)之后$x$的右儿子被清空了,所以将该splay翻转就得到新的按深度递增的splay(仔细想想是不是)。于是代码就出来了:

1
2
3
4
5
void makeroot(int x) {
access(x);
splay(x);
tag[x] ^= 1; // 这是翻转的标记
}
split(x, y)

将$x$成为原树的根(同时是该splay的深度最小节点),再将$y$与$x$打通,旋转$y$到该splay的根,显然$y$的当前$sum$值就是$x$到$y$路径的$sum$。

1
2
3
4
5
void split(int x, int y) {
makeroot(x);
access(y);
splay(y);
}
findroot(x)

将$x$与根打通(access(x)),显然splay最左边的节点就是根了。

1
2
3
4
5
6
7
8
int findroot(int x) {
access(x);
splay(x);
while (ch[x][0])
pushdown(x), x = ch[x][0];
splay(x);
return x;
}

先判断$x$与$y$是否联通,类似并查集的方法。

然后$y$转到根,将$x$指向$y$(注意不要让父亲$y$指向儿子$x$)。

1
2
3
4
5
6
7
bool link(int x, int y) {
makeroot(x);
if (findroot(y) == x)
return false;
fa[x] = y;
return true;
}
cut(x, y)

还是先判断$x$与$y$在同一个联通块。

将$x$到$y$的路径取出,并判断$x$与$y$是否直接连接。

如果直接连接,直接断开即可。

1
2
3
4
5
6
7
8
9
bool cut(int x, int y) {
if (findroot(x) != findroot(y))
return false;
split(x, y);
if (fa[x] != y || ch[x][1])
return false;
fa[x] = ch[y][0] = 0;
return true;
}

至此,LCT的各函数实现就讲完了,下面放出一个封装好的模版:

模版

理论时间复杂度是$O(n \log n)$(我也不会证明),但是常数极大(听说能有$20$?),慎用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
const int MAXN = 1e5 + 10;
struct LCT {
int fa[MAXN], ch[MAXN][2];
int val[MAXN], sum[MAXN], tag[MAXN];
void pushup(int x) { sum[x] = sum[ch[x][0]] ^ sum[ch[x][1]] ^ val[x]; }
void connect(int x, int f, int w) { fa[x] = f, (~w) && (ch[f][w] = x); }
int get(int x) {
if (ch[fa[x]][0] == x)
return 0;
if (ch[fa[x]][1] == x)
return 1;
return -1;
}
void pushdown(int x) {
if (tag[x]) {
swap(ch[x][0], ch[x][1]);
if (ch[x][0])
tag[ch[x][0]] ^= 1;
if (ch[x][1])
tag[ch[x][1]] ^= 1;
tag[x] = 0;
}
}
void pushall(int x) {
if (~get(x))
pushall(fa[x]);
pushdown(x);
}
void rotate(int x) {
int y = fa[x], z = fa[y], w = get(x);
if (~w)
connect(ch[x][w ^ 1], y, w);
connect(x, z, get(y));
connect(y, x, w ^ 1);
pushup(y);
pushup(x);
}
void splay(int x) {
pushall(x);
for (int f = fa[x]; f = fa[x], ~get(x); rotate(x))
if (~get(f))
rotate(get(f) == get(x) ? f : x);
}
void access(int x) {
int pre = 0;
while (x) {
splay(x);
ch[x][1] = pre;
pushup(x);
pre = x;
x = fa[x];
}
}
void makeroot(int x) {
access(x);
splay(x);
tag[x] ^= 1;
}
void split(int x, int y) {
makeroot(x);
access(y);
splay(y);
}
int findroot(int x) {
access(x);
splay(x);
while (ch[x][0])
pushdown(x), x = ch[x][0];
splay(x);
return x;
}
bool link(int x, int y) {
makeroot(x);
if (findroot(y) == x)
return false;
fa[x] = y;
return true;
}
bool cut(int x, int y) {
if (findroot(x) != findroot(y))
return false;
split(x, y);
if (fa[x] != y || ch[x][1])
return false;
fa[x] = ch[y][0] = 0;
return true;
}
};

AC代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
/**
* @file: P3690.cpp
* @author: yaoxi-std
* @url: https://www.luogu.com.cn/problem/P3690
*/
// #pragma GCC optimize ("O2")
// #pragma GCC optimize ("Ofast", "inline", "-ffast-math")
// #pragma GCC target ("avx,sse2,sse3,sse4,mmx")
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define resetIO(x) \
freopen(#x ".in", "r", stdin), freopen(#x ".out", "w", stdout)
#define debug(fmt, ...) \
fprintf(stderr, "[%s:%d] " fmt "\n", __FILE__, __LINE__, ##__VA_ARGS__)
template <class _Tp>
inline _Tp& read(_Tp& x) {
bool sign = false;
char ch = getchar();
long double tmp = 1;
for (; !isdigit(ch); ch = getchar())
sign |= (ch == '-');
for (x = 0; isdigit(ch); ch = getchar())
x = x * 10 + (ch ^ 48);
if (ch == '.')
for (ch = getchar(); isdigit(ch); ch = getchar())
tmp /= 10.0, x += tmp * (ch ^ 48);
return sign ? (x = -x) : x;
}
template <class _Tp>
inline void write(_Tp x) {
if (x < 0)
putchar('-'), x = -x;
if (x > 9)
write(x / 10);
putchar((x % 10) ^ 48);
}
const int MAXN = 1e5 + 10;
const int INF = 0x3f3f3f3f3f3f3f3f;
struct LCT {
int fa[MAXN], ch[MAXN][2];
int val[MAXN], sum[MAXN], tag[MAXN];
void pushup(int x) { sum[x] = sum[ch[x][0]] ^ sum[ch[x][1]] ^ val[x]; }
void connect(int x, int f, int w) { fa[x] = f, (~w) && (ch[f][w] = x); }
int get(int x) {
if (ch[fa[x]][0] == x)
return 0;
if (ch[fa[x]][1] == x)
return 1;
return -1;
}
void pushdown(int x) {
if (tag[x]) {
swap(ch[x][0], ch[x][1]);
if (ch[x][0])
tag[ch[x][0]] ^= 1;
if (ch[x][1])
tag[ch[x][1]] ^= 1;
tag[x] = 0;
}
}
void pushall(int x) {
if (~get(x))
pushall(fa[x]);
pushdown(x);
}
void rotate(int x) {
int y = fa[x], z = fa[y], w = get(x);
if (~w)
connect(ch[x][w ^ 1], y, w);
connect(x, z, get(y));
connect(y, x, w ^ 1);
pushup(y);
pushup(x);
}
void splay(int x) {
pushall(x);
for (int f = fa[x]; f = fa[x], ~get(x); rotate(x))
if (~get(f))
rotate(get(f) == get(x) ? f : x);
}
void access(int x) {
int pre = 0;
while (x) {
splay(x);
ch[x][1] = pre;
pushup(x);
pre = x;
x = fa[x];
}
}
void makeroot(int x) {
access(x);
splay(x);
tag[x] ^= 1;
}
void split(int x, int y) {
makeroot(x);
access(y);
splay(y);
}
int findroot(int x) {
access(x);
splay(x);
while (ch[x][0])
pushdown(x), x = ch[x][0];
splay(x);
return x;
}
bool link(int x, int y) {
makeroot(x);
if (findroot(y) == x)
return false;
fa[x] = y;
return true;
}
bool cut(int x, int y) {
if (findroot(x) != findroot(y))
return false;
split(x, y);
if (fa[x] != y || ch[x][1])
return false;
fa[x] = ch[y][0] = 0;
return true;
}
};
int n, q;
LCT lct;
signed main() {
read(n), read(q);
for (int i = 1; i <= n; ++i) {
read(lct.val[i]);
lct.sum[i] = lct.val[i];
}
while (q--) {
int op, x, y;
read(op), read(x), read(y);
if (op == 0) {
lct.split(x, y);
write(lct.sum[y]), putchar('\n');
} else if (op == 1) {
lct.link(x, y);
} else if (op == 2) {
lct.cut(x, y);
} else {
lct.val[x] = y;
lct.splay(x);
}
}
return 0;
}