yaoxi-std 的博客

$\text{开}\mathop{\text{卷}}\limits^{ju\check{a}n}\text{有益}$

0%

P3233 世界树 题解

P3233 世界树 题解

题面

题目链接

解法

以下所有的“关键点”都表示题目中的议事处

看完这题题解后去学习了虚树(然后day2上午模拟赛考了道类似虚树但不用建$lca$的题目)

于是看到$\sum m \le 3 \times 10^5$ 直接建立虚树(然而如果没学过虚树就看不出来了)。

pair<int, int>储存树上节点到达最近关键点的距离和关键点编号以便于比较。先对虚树上的节点做一遍换根$dp$,方程显然是$dpu = \min\limits{v,w}{dp_v + w}$其中$v, w$分别为子节点和到子节点的距离。之后为了方便表述设$f_u$表示dp[u].first,$g_u$表示dp[u].second

接下来考虑不在虚树上的节点。首先如果虚树上某节点的一个原树上子节点的子树中都没有虚树上的节点,那么该子树中的的所有节点的最近关键点都应当与该虚树节点相同。形式化地,若对于虚树上的节点$u$有节点$u \to v{real}$使得不存在$v{real} \to \cdots \to x$在虚树上,那么$gx = g_u$,即$cnt{gu} += siz{v_{real}}$。如在下图中的节点中$v=6,v=7,v=10$都满足该种情况(加粗节点表示虚树节点)。

虚树

其次是在虚树的链上的点。不妨将例如上图中的节点$9$计算到节点$4$中(因为$g4$一定等于$g_9$),考虑一条链$fa \to u$,若$g{fa} = gu$则链上节点全部加到$g_u$中(同样设$v{real}$表示$fa$在原树中的$u$方向的子节点,相当于$cnt{g_u} += siz{v{real}} - siz_u$),否则倍增找到深度最小的满足取$g_u$更优的点$cur$,显然$cnt{gu} += siz{cur} - sizu$,$cnt{g{fa}} += siz{v{real}} - siz{cur}$。

于是本题就结束了。只是有太多太多细节要注意了,写错一个调一整年啊/kk

时间复杂度$O(n \log n)$稳过。

AC代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
/**
* @file: P3233.cpp
* @author: yaoxi-std
* @url: https://www.luogu.com.cn/problem/P3233
*/
// #pragma GCC optimize ("O2")
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define fi first
#define se second
#define resetIO(x) \
freopen(#x ".in", "r", stdin), freopen(#x ".out", "w", stdout)
#define debug(fmt, ...) \
fprintf(stderr, "[%s:%d] " fmt "\n", __FILE__, __LINE__, ##__VA_ARGS__)
template <class _Tp>
inline _Tp& read(_Tp &x) {
bool sign = false;
char ch = getchar();
long double tmp = 1;
for (; !isdigit(ch); ch = getchar())
sign |= (ch == '-');
for (x = 0; isdigit(ch); ch = getchar())
x = x * 10 + (ch ^ 48);
if (ch == '.')
for (ch = getchar(); isdigit(ch); ch = getchar())
tmp /= 10.0, x += tmp * (ch ^ 48);
return sign ? (x = -x) : x;
}
template <class _Tp>
inline void write(_Tp x) {
if (x < 0)
putchar('-'), x = -x;
if (x > 9)
write(x / 10);
putchar((x % 10) ^ 48);
}
using pii = pair<int, int>;
const int MAXN = 3e5 + 10;
const int MAXM = 6e5 + 10;
const int MAXK = 20;
const int INFL = 0x3f3f3f3f3f3f3f3f;
struct graph {
struct edge {
int u, v, w;
} e[MAXM];
int head[MAXN], nxt[MAXM], tot;
void addedge(int u, int v, int w = 1) {
e[++tot] = {u, v, w};
nxt[tot] = head[u];
head[u] = tot;
}
};
int n, m, q, top, inde, h[MAXN], dfn[MAXN], dep[MAXN];
int fa[MAXN][MAXK], siz[MAXN], sta[MAXN], cnt[MAXN], ori[MAXN];
pii dis[MAXN][2], tmp[MAXN], pre[MAXN], nxt[MAXN];
bool tag[MAXN];
graph tr, vtr;
void build_tr(int u, int f) {
siz[u] = 1, dfn[u] = ++inde, dep[u] = dep[f] + 1, fa[u][0] = f;
for (int i = 1; i < MAXK; ++i)
fa[u][i] = fa[fa[u][i - 1]][i - 1];
for (int i = tr.head[u]; i; i = tr.nxt[i])
if (tr.e[i].v != f)
build_tr(tr.e[i].v, u), siz[u] += siz[tr.e[i].v];
}
int lca(int u, int v) {
if (dep[u] < dep[v])
swap(u, v);
int t = dep[u] - dep[v];
for (int i = MAXK - 1; ~i; --i)
if ((t >> i) & 1)
u = fa[u][i];
if (u == v)
return u;
for (int i = MAXK - 1; ~i; --i)
if (fa[u][i] != fa[v][i])
u = fa[u][i], v = fa[v][i];
return fa[u][0];
}
int jump(int u, int d) {
for (int i = MAXK - 1; ~i; --i)
if ((d >> i) & 1)
u = fa[u][i];
return u;
}
void add_vedge(int u, int v) {
// debug("virtual tree: %lld -> %lld, w = %lld", u, v, dep[v] - dep[u]);
vtr.addedge(u, v, dep[v] - dep[u]);
}
void build_vtr() {
sort(h + 1, h + m + 1, [](int x, int y) {
return dfn[x] < dfn[y];
});
sta[top = 1] = 1, vtr.head[1] = 0, vtr.tot = 0;
for (int i = 1; i <= m; ++i) {
if (h[i] == 1)
continue;
int l = lca(sta[top], h[i]);
while (dfn[l] <= dfn[sta[top - 1]])
add_vedge(sta[top - 1], sta[top]), --top;
if (sta[top] != l)
vtr.head[l] = 0, add_vedge(l, sta[top]), sta[top] = l;
vtr.head[h[i]] = 0, sta[++top] = h[i];
}
for (int i = 1; i < top; ++i)
add_vedge(sta[i], sta[i + 1]);
}
void dfs0(int u) {
dis[u][0] = tag[u] ? pii{0, u} : pii{INFL, 0};
for (int i = vtr.head[u]; i; i = vtr.nxt[i]) {
int v = vtr.e[i].v, w = vtr.e[i].w;
dfs0(v), dis[u][0] = min(dis[u][0], {dis[v][0].fi + w, dis[v][0].se});
}
// debug("dis of dfs0(%lld) = {%lld, %lld}", u, dis[u][0].fi, dis[u][0].se);
}
void dfs1(int u) {
int len = 0;
for (int i = vtr.head[u]; i; i = vtr.nxt[i]) {
int v = vtr.e[i].v, w = vtr.e[i].w;
tmp[++len] = {dis[v][0].fi + w, dis[v][0].se};
}
if (tag[u])
dis[u][1] = {0, u};
pre[0] = nxt[len + 1] = dis[u][1];
for (int i = 1; i <= len; ++i)
pre[i] = min(pre[i - 1], tmp[i]);
for (int i = len; i >= 1; --i)
nxt[i] = min(nxt[i + 1], tmp[i]);
len = 0;
for (int i = vtr.head[u]; i; i = vtr.nxt[i]) {
++len;
pii t = min(pre[len - 1], nxt[len + 1]);
// ##sb-mistakes## 换根$dp$不要在统计$pre$和$nxt$数组时就写$dfs$!!!数组整个改变!!!不然会死得很惨(指对着n=1000,m=100的大样例调1h)
dis[vtr.e[i].v][1] = {t.fi + vtr.e[i].w, t.se};
}
for (int i = vtr.head[u]; i; i = vtr.nxt[i])
dfs1(vtr.e[i].v);
// debug("dis of dfs1(%lld) = {%lld, %lld}", u, dis[u][1].fi, dis[u][1].se);
dis[u][1] = min(dis[u][0], dis[u][1]);
++cnt[dis[u][1].se];
// debug("dis of dfs1(%lld) = {%lld, %lld}", u, dis[u][1].fi, dis[u][1].se);
}
void dfs2(int u, int f) {
for (int i = vtr.head[u]; i; i = vtr.nxt[i])
dfs2(vtr.e[i].v, u);
if (f) {
// debug("calculate chain %lld -> %lld", f, u);
if (dis[f][1].se == dis[u][1].se) {
int cx = siz[jump(u, dep[u] - dep[f] - 1)] - siz[u];
// debug("same color, add %lld to cnt[%lld]", cx, dis[u][1].se);
cnt[dis[u][1].se] += cx;
} else {
int cur = u;
for (int i = MAXK - 1; ~i; --i) {
if (dep[fa[cur][i]] <= dep[f])
continue;
pii up{dep[fa[cur][i]] - dep[f] + dis[f][1].fi, dis[f][1].se};
pii dn{dep[u] - dep[fa[cur][i]] + dis[u][1].fi, dis[u][1].se};
if (dn < up)
cur = fa[cur][i];
}
// debug("vertex is %lld", cur);
int c1 = siz[cur] - siz[u];
int c2 = siz[jump(u, dep[u] - dep[f] - 1)] - siz[cur];
// debug("%lld", siz[jump(u, dep[u] - dep[f] - 1)]);
// debug("diff color, add %lld to cnt[%lld]", c1, dis[u][1].se);
// debug("diff color, add %lld to cnt[%lld]", c2, dis[f][1].se);
cnt[dis[u][1].se] += c1;
cnt[dis[f][1].se] += c2;
}
}
// debug("calculate subtree of %lld", u);
int rem = siz[u] - 1;
for (int i = vtr.head[u]; i; i = vtr.nxt[i])
rem -= siz[jump(vtr.e[i].v, dep[vtr.e[i].v] - dep[u] - 1)];
// debug("add %lld to cnt[%lld]", rem, dis[u][1].se);
cnt[dis[u][1].se] += rem;
}
signed main() {
read(n);
for (int i = 1; i < n; ++i) {
int u, v;
read(u), read(v);
tr.addedge(u, v);
tr.addedge(v, u);
}
build_tr(1, 0);
read(q);
while (q--) {
read(m);
for (int i = 1; i <= m; ++i)
ori[i] = read(h[i]), tag[h[i]] = true, cnt[h[i]] = 0;
build_vtr();
dis[1][0] = dis[1][1] = {INFL, 0};
dfs0(1), dfs1(1), dfs2(1, 0);
for (int i = 1; i <= m; ++i)
write(cnt[ori[i]]), putchar(i == m ? '\n' : ' ');
for (int i = 1; i <= m; ++i)
tag[h[i]] = false;
}
return 0;
}