yaoxi-std 的博客

$\text{开}\mathop{\text{卷}}\limits^{ju\check{a}n}\text{有益}$

0%

P4002 生成树计数

P4002 生成树计数

黑,真$^{\text{TM}}$的黑啊(指题目的颜色)

题面

题目链接

前置知识:求数列幂和

给定序列${an}$,对于$0\le t \le k$,求$\sum\limits{i=1}^n a_i^t$。

先写出幂次和所求值之间的生成函数:

不妨设$G(x)=\sum\limits_{i=1}^{n}(\ln’(1-a_ix))$,则有

于是求出$G(x)$后可以很容易地求出$F(x)$。思考如何求$G(x)$:

使用分治FFT求出$\prod\limits_{i=1}^n(1-a_ix)$,再$\ln$和求导,求出$G(x)$和$F(x)$,时间复杂度$O(n \log^2 n)$。

多项式$\ln \mid \exp$

P4725 【模版】多项式对数函数(多项式ln)P4726 【模版】多项式指数函数(多项式exp)

解法

首先,你需要想到使用prufer序列来进行树形态的枚举并计算贡献。

根据prufer序列的定义,一棵$n$个节点的树共有$n^{n-2}$种不同的形态。

我们计算答案时需要的是每个节点的度$deg_u$。

为了方便设$d_u=deg_u-1$,$d_u$的实际意义是$u$在prufer序列中的出现次数。

所以枚举$\sum d_i=n-2$,然后prufer序列的$n-2$个数字排列共$(n-2)!$种情况。于是乎

左边两项定值先不管,右边几项如下:

构造两个多项式$A(x)$和$B(x)$:

则有

当我们求出$\frac{A(x)}{B(x)}$和$\ln(B(x))$以后,要想求出$\sum\limitsi \frac{A(a_ix)}{B(a_ix)}$和$\sum\limits_i \ln(B(a_ix))$,则需要将每项系数乘上$\sum\limits{i=1}^na_i^k$,用前置知识中的方法分治计算。

总的时间复杂度为$O(n \log^2 n)$。

AC代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
/**
* @file: P4002.cpp
* @author: yaoxi-std
* @url: https://www.luogu.com.cn/problem/P4002
*/
// #pragma GCC optimize ("O2")
// #pragma GCC optimize ("Ofast", "inline", "-ffast-math")
// #pragma GCC target ("avx,sse2,sse3,sse4,mmx")
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define resetIO(x) \
freopen(#x ".in", "r", stdin), freopen(#x ".out", "w", stdout)
#define debug(fmt, ...) \
fprintf(stderr, "[%s:%d] " fmt "\n", __FILE__, __LINE__, ##__VA_ARGS__)
template <class _Tp>
inline _Tp& read(_Tp& x) {
bool sign = false;
char ch = getchar();
long double tmp = 1;
for (; !isdigit(ch); ch = getchar())
sign |= (ch == '-');
for (x = 0; isdigit(ch); ch = getchar())
x = x * 10 + (ch ^ 48);
if (ch == '.')
for (ch = getchar(); isdigit(ch); ch = getchar())
tmp /= 10.0, x += tmp * (ch ^ 48);
return sign ? (x = -x) : x;
}
template <class _Tp>
inline void write(_Tp x) {
if (x < 0)
putchar('-'), x = -x;
if (x > 9)
write(x / 10);
putchar((x % 10) ^ 48);
}
const int MAXN = 1 << 17;
const int INF = 0x3f3f3f3f3f3f3f3f;
const int MOD = 998244353;
namespace maths {
int add(int x, int y) {
x += y;
return x >= MOD ? x - MOD : x;
}
int sub(int x, int y) {
x -= y;
return x < 0 ? x + MOD : x;
}
int qpow(int x, int y, int p = MOD) {
int ret = 1;
for (; y; y >>= 1, x = x * x % p)
if (y & 1)
ret = ret * x % p;
return ret;
}
template <class _Tp>
void change(_Tp* f, int len) {
static int rev[MAXN] = {};
for (int i = 0; i < len; ++i) {
rev[i] = rev[i >> 1] >> 1;
if (i & 1)
rev[i] |= len >> 1;
}
for (int i = 0; i < len; ++i)
if (i < rev[i])
swap(f[i], f[rev[i]]);
}
void ntt(int* f, int len, int on) {
change(f, len);
for (int h = 2; h <= len; h <<= 1) {
int gn = qpow(3, (MOD - 1) / h);
for (int j = 0; j < len; j += h) {
int g = 1;
for (int k = j; k < j + h / 2; ++k) {
int u = f[k], t = g * f[k + h / 2] % MOD;
f[k] = add(u, t), f[k + h / 2] = sub(u, t);
g = g * gn % MOD;
}
}
}
if (on == -1) {
reverse(f + 1, f + len);
int inv = qpow(len, MOD - 2);
for (int i = 0; i < len; ++i)
f[i] = f[i] * inv % MOD;
}
}
void poly_multiply(const int* f, int n, const int* g, int m, int* ans) {
static int tf[MAXN] = {}, tg[MAXN] = {};
int len = 1;
while (len < n + m - 1)
len <<= 1;
copy(f, f + n, tf);
fill(tf + n, tf + len, 0);
copy(g, g + m, tg);
fill(tg + m, tg + len, 0);
fill(ans, ans + len, 0);
ntt(tf, len, 1);
ntt(tg, len, 1);
for (int i = 0; i < len; ++i)
tf[i] = tf[i] * tg[i] % MOD;
ntt(tf, len, -1);
copy(tf, tf + n + m - 1, ans);
}
void poly_inverse(const int* f, int n, int* ans) {
static int tmp[MAXN] = {};
int len = 1;
while (len < n)
len <<= 1;
fill(ans, ans + len + len, 0);
ans[0] = qpow(f[0], MOD - 2);
for (int h = 2; h <= len; h <<= 1) {
copy(f, f + h, tmp);
fill(tmp + h, tmp + h + h, 0);
ntt(tmp, h + h, 1);
ntt(ans, h + h, 1);
for (int i = 0; i < h + h; ++i)
ans[i] = ans[i] * (2 - ans[i] * tmp[i] % MOD + MOD) % MOD;
ntt(ans, h + h, -1);
fill(ans + h, ans + h + h, 0);
}
}
void poly_derivation(const int* f, int n, int* ans) {
for (int i = 0; i < n - 1; ++i)
ans[i] = f[i + 1] * (i + 1) % MOD;
ans[n - 1] = 0;
}
void poly_integral(const int* f, int n, int* ans) {
for (int i = n; i >= 1; --i)
ans[i] = f[i - 1] * qpow(i, MOD - 2) % MOD;
ans[0] = 0;
}
void poly_ln(const int* f, int n, int* ans) {
static int tf[MAXN] = {}, tg[MAXN] = {};
poly_derivation(f, n, tf);
poly_inverse(f, n, tg);
poly_multiply(tf, n - 1, tg, n, ans);
poly_integral(ans, n - 1, ans);
fill(ans + n, ans + n + n, 0);
}
void poly_exp(const int* f, int n, int* ans) {
static int tf[MAXN] = {}, tg[MAXN] = {};
ans[0] = 1, ans[1] = 0;
for (int h = 2; h <= (n << 1); h <<= 1) {
poly_ln(ans, h, tf);
for (int i = 0; i < h; ++i)
tf[i] = add((i == 0), sub(f[i], tf[i]));
copy(ans, ans + h, tg);
poly_multiply(tf, h, tg, h, ans);
}
}
}; // namespace maths
using namespace maths;
int n, m, a[MAXN], ta[MAXN], tb[MAXN], tc[MAXN], ts[MAXN];
int fac[MAXN], inv[MAXN];
void solve(int l, int r, int cl, int cr) {
if (l + 1 == r) {
ts[cl] = 1, ts[cl + 1] = sub(0, a[l]);
return;
}
int mid = (l + r) >> 1, cmid = (cl + cr) >> 1;
solve(l, mid, cl, cmid);
solve(mid, r, cmid, cr);
poly_multiply(ts + cl, cmid - cl, ts + cmid, cr - cmid, ts + cl);
}
signed main() {
read(n), read(m);
for (int i = 1; i <= n; ++i)
read(a[i]);
fac[0] = 1;
for (int i = 1; i <= n; ++i)
fac[i] = fac[i - 1] * i % MOD;
inv[n] = qpow(fac[n], MOD - 2);
for (int i = n - 1; ~i; --i)
inv[i] = inv[i + 1] * (i + 1) % MOD;
int len = 1;
while (len <= n)
len <<= 1;
for (int i = 0; i < len; ++i)
ta[i] = qpow(i + 1, m + m) * inv[i] % MOD;
for (int i = 0; i < len; ++i)
tb[i] = qpow(i + 1, m) * inv[i] % MOD;
poly_inverse(tb, len, tc);
poly_multiply(ta, len, tc, len, ta);
poly_ln(tb, len, tc);
copy(tc, tc + len, tb);
solve(0, len, 0, len << 1);
poly_ln(ts, len, ts);
for (int i = 1; i <= n; ++i)
ts[i] = sub(0, ts[i] * i % MOD);
ts[0] = n;
fill(ts + n, ts + len, 0);
for (int i = 0; i < len; ++i) {
ta[i] = ta[i] * ts[i] % MOD;
tb[i] = tb[i] * ts[i] % MOD;
}
poly_exp(tb, len, tc);
poly_multiply(ta, len, tc, len, ta);
int ans = fac[n - 2] * ta[n - 2] % MOD;
for (int i = 1; i <= n; ++i)
ans = ans * a[i] % MOD;
write(ans), putchar('\n');
return 0;
}