yaoxi-std 的博客

$\text{开}\mathop{\text{卷}}\limits^{ju\check{a}n}\text{有益}$

0%

P3803 【模版】多项式乘法(FFT)

P3803 【模版】多项式乘法(FFT)

题面

题目链接

解法

FFT 模版题

为了防止自己忘记就写一下证明吧。顺便把 NTT 也写一下。

前置知识

多项式表示方法

系数表示法

点值表示法

不妨将多项式看成一个$n-1$次函数,从上面取$n$个点来唯一地表示这个函数。

设想一下高斯消元法,就能知道为什么$n$个不同的点就能唯一确定这个函数了。

这样的表示法有一个好处,就是如果要计算多项式乘法,设

那么

就可以$O(n)$地求出多项式乘法。

所以FFT要做的事就是将系数表示法和点值表示法进行转换。

复数(这个大家都会)

令$i^2 = -1$,复数可被表示为$a + bi$的形式

考虑在复平面上的两个向量$(a,b)$和$(c,d)$,将其表示的复数相乘得到$(a + bi) \times (c + di) = ac - bd + (ad + bc)i$,即向量$(ac - bd, ad + bc)$。

我们计算几个向量的模,分别为$\sqrt{a^2 + b^2}$,$\sqrt{c^2 + d^2}$和$\sqrt{a^2c^2 + b^2d^2 + a^2d^2 + b^2c^2} = \sqrt{(a^2 + b^2) \times (c^2 + d^2)}$,即两个向量模长的乘积。

所以如果两个原向量模长都为$1$,乘积的向量也为$1$。

假设我们有两个复平面上单位圆上的向量,设其辐角分别为$\alpha$和$\beta$,则这两个向量表示为$(\cos\alpha,\sin\alpha)$和$(\cos\beta,\sin\beta)$,其乘积为$(\cos\alpha\cos\beta-\sin\alpha\sin\beta,\cos\alpha\sin\beta+\sin\alpha\cos\beta)$。根据二角和差公式

可以发现这个新的向量的辐角就等于两个原向量的辐角相加。于是将得到结论:两个模长为$1$的向量相乘,得到的仍是模长为$1$的向量,辐角为两个向量辐角的和。

单位复根

由于我们要去计算若干个$x_i$对应的$f(x_i)$,最好的办法便是找一些特殊的数值带进去计算。这里引入单位复根的概念。

我们称$x^n = 1$在复数意义下的解是$n$次复根。显然这样的解有$n$个。设$\omega_n = e^{\frac{2\pi i}{n}}$,则$x^n = 1$的解集表示为${\omega^k_n \mid k=0,1,\cdots,n-1}$,称$w_n$为$n$次单位复根。根据复平面的知识,$n$次单位复根是复平面把单位圆$n$等分的第一个角所对应的向量,其他复根均可以用单位复根的幂表示。

所以显然还能得到$\omega_n = e^{\frac{2\pi i}{n}} = \cos(\frac{2\pi}{n}) + i \sin(\frac{2\pi}{n})$。

举个例子,$n=4$时,$w_n = i$,如图所示(图来自oi-wiki

fft-1

并且单位复根还有三个重要的性质如下:

终于开始讲FFT了

FFT其本质为分治算法。比方说对于

按照次数的奇偶来分组得到

右边提取$x$得到

按照奇偶次项建立新的函数

原来的$f(x)$可以被表示成

利用单位复根的性质得到

同理可得

而由于$n/2$需要一直为整数,所以$n$需要是$2^m$,不妨在一开始将多项式的次数补到长度为长度为$2^m$,次数为$2^m-1$的多项式即可。

代码实现方面,STL提供了复数的模版(我也是第一次知道有这种好事),可以使用std::complex<double>

NTT

模版和FFT基本相同,用原根代替单位复根(性质基本相同)。

时间复杂度$O(n \log n)$。

AC代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
/**
* @file: P3803.cpp
* @author: yaoxi-std
* @url: https://www.luogu.com.cn/problem/P3803
*/
// #pragma GCC optimize ("O2")
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define resetIO(x) \
freopen(#x ".in", "r", stdin), freopen(#x ".out", "w", stdout)
#define debug(fmt, ...) \
fprintf(stderr, "[%s:%d] " fmt "\n", __FILE__, __LINE__, ##__VA_ARGS__)
template <class _Tp>
inline _Tp& read(_Tp &x) {
bool sign = false;
char ch = getchar();
long double tmp = 1;
for (; !isdigit(ch); ch = getchar())
sign |= (ch == '-');
for (x = 0; isdigit(ch); ch = getchar())
x = x * 10 + (ch ^ 48);
if (ch == '.')
for (ch = getchar(); isdigit(ch); ch = getchar())
tmp /= 10.0, x += tmp * (ch ^ 48);
return sign ? (x = -x) : x;
}
template <class _Tp>
inline void write(_Tp x) {
if (x < 0)
putchar('-'), x = -x;
if (x > 9)
write(x / 10);
putchar((x % 10) ^ 48);
}
using comp = complex<double>;
const int MAXN = 1 << 21;
const int INFL = 0x3f3f3f3f3f3f3f3f;
const int MOD = 998244353;
namespace fft {
int rev[MAXN];
void change(comp *f, int len) {
for (int i = 0; i < len; ++i) {
rev[i] = rev[i >> 1] >> 1;
if (i & 1)
rev[i] |= len >> 1;
}
for (int i = 0; i < len; ++i)
if (i < rev[i])
swap(f[i], f[rev[i]]);
}
void fft(comp *f, int len, int on) {
change(f, len);
for (int h = 2; h <= len; h <<= 1) {
comp wn(cos(2 * M_PI / h), sin(2 * M_PI / h));
for (int j = 0; j < len; j += h) {
comp w(1, 0);
for (int k = j; k < j + h / 2; ++k) {
comp u = f[k], t = w * f[k + h / 2];
f[k] = u + t;
f[k + h / 2] = u - t;
w = w * wn;
}
}
}
if (on == -1) {
reverse(f + 1, f + len);
for (int i = 0; i < len; ++i)
f[i].real(f[i].real() / len);
}
}
int n, m, len;
comp a[MAXN], b[MAXN], ans[MAXN];
signed main() {
read(n), read(m);
for (int i = 0; i <= n; ++i) {
int x;
read(x);
a[i].real(x);
}
for (int i = 0; i <= m; ++i) {
int x;
read(x);
b[i].real(x);
}
len = 1;
while (len <= n + m)
len *= 2;
fft(a, len, 1), fft(b, len, 1);
for (int i = 0; i < len; ++i)
ans[i] = a[i] * b[i];
fft(ans, len, -1);
for (int i = 0; i <= n + m; ++i)
printf("%lld%c", (int)(ans[i].real() + 0.5), " \n"[i == n + m]);
return 0;
}
}
namespace ntt {
int rev[MAXN];
int qpow(int x, int y) {
int ret = 1;
for (; y; y >>= 1, x = x * x % MOD)
if (y & 1)
ret = ret * x % MOD;
return ret;
}
void change(int *f, int len) {
for (int i = 0; i < len; ++i) {
rev[i] = rev[i >> 1] >> 1;
if (i & 1)
rev[i] |= len >> 1;
}
for (int i = 0; i < len; ++i)
if (i < rev[i])
swap(f[i], f[rev[i]]);
}
void ntt(int *f, int len, int on) {
change(f, len);
for (int h = 2; h <= len; h <<= 1) {
int gn = qpow(3, (MOD - 1) / h);
for (int j = 0; j < len; j += h) {
int g = 1;
for (int k = j; k < j + h / 2; ++k) {
int u = f[k], t = g * f[k + h / 2] % MOD;
f[k] = (u + t + MOD) % MOD;
f[k + h / 2] = (u - t + MOD) % MOD;
g = g * gn % MOD;
}
}
}
if (on == -1) {
reverse(f + 1, f + len);
int inv = qpow(len, MOD - 2);
for (int i = 0; i < len; ++i)
f[i] = f[i] * inv % MOD;
}
}
int n, m, len;
int a[MAXN], b[MAXN], ans[MAXN];
signed main() {
read(n), read(m);
for (int i = 0; i <= n; ++i)
read(a[i]);
for (int i = 0; i <= m; ++i)
read(b[i]);
len = 1;
while (len <= n + m)
len *= 2;
ntt(a, len, 1), ntt(b, len, 1);
for (int i = 0; i < len; ++i)
ans[i] = a[i] * b[i] % MOD;
ntt(ans, len, -1);
for (int i = 0; i <= n + m; ++i)
printf("%lld%c", ans[i], " \n"[i == n + m]);
return 0;
}
}
signed main() {
// return fft::main();
return ntt::main();
}