yaoxi-std 的博客

$\text{开}\mathop{\text{卷}}\limits^{ju\check{a}n}\text{有益}$

0%

P4245 【模版】任意模数多项式乘法

P4245 【模版】任意模数多项式乘法

题面

题目链接

解法

对于任意模数的多项式乘法显然无法直接套用$NTT$。而由于数量级过大,采用std::complex<long double>的$FFT$也会产生比较多的精度丢失。

方法1: $NTT$合并

由于$a_i,b_i \le 10^9$且$n \le 10^5$,所以答案的数量级不超过$10^9 \times 10^9 \times 10^5 = 10^{23}$,这样就可以找$3$个$10^8$左右的模数分别用$NTT$算出对其取模的答案,用中国剩余定理合并。

这里比较常用的$3$个模数分别是

下面假设我们求出了对这$3$个模数的答案分别为$x_1,x_2,x_3$,得到同余方程如下:

由于乘积过大无法用long long储存,所以正常的$CRT$不能在此使用(当然如果你写高精度就当我没说),于是考虑手动合并。先把前两个合并:

这样便求出了$k_1$,同时得到新的方程$x \equiv x_1 + k_1 p_1 \pmod{p_1 p_2}$。记$x_4 = x_1 + k_1 p_1$。

求出$k_4$后,得到$x \equiv x_4 + k_4 p_1 p_2 \pmod{p_3}$。由于$x \lt p_1 p_2 p_3$,所以$x = x_4 + k_4 p_1 p_2$。

共需要跑$3 \times 3 = 9$次$NTT$,常数较大。

方法2: 拆系数$FFT$

把一个数拆分成$a \times 2^{15} + b$的形式,则$a, b \lt 2^{15}$

将$a$和$b$分别做多项式,相乘的值域是$2^{15} \times 2^{15} \times 10^5 \approx 10^{14}$,可以接受。于是

乍一看好像需要算$4$次乘法,共$12$次$FFT$,那岂不是比方法1更劣?

开始推式子吧。假设有$4$个多项式$A_1,A_2,B_1,B_2$,如何求他们的两两乘积。

由于$(a+bi) \times (c+di) = (ac-bd) + (ad+bc)i$,所以我们设复多项式$P = A_1 + iB_1$,$Q = A_2 + iB_2$(这是哪位神仙想出来的啊),而$FFT$本身就是使用复数计算所以直接传入复数也没关系。

设$T_1 = P \times Q = A_1 A_2 - B_1 B_2 + (A_1 B_2 + A_2 B_1)i$,

又设$P’ = A_1 - iB_1$,

那$T_2 = P’ \times Q = A_1 A_2 + B_1 B_2 + (A_1 B_2 - A_2 B_1)i$,

两者$T_1$和$T_2$进行和差就得到了多项式的两两乘积(妙啊)。

总的$FFT$次数为$3$次DFT$+2$次IDFT$=5$次。

突然发现值域其实应该是$10^{19}$而不是$10^{14}$,因为$IDFT$之前还得除以$n$……

看数据强度跑吧,代码倒是挺好写的。

计算$\pi$不要用常量M_PI而是要用acos(-1)!!!!不然这题丢失精度只能过前$10$个点!!!!

AC代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
/**
* @file: P4245.cpp
* @author: yaoxi-std
* @url: https://www.luogu.com.cn/problem/P4245
*/
// #pragma GCC optimize ("O2")
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define resetIO(x) \
freopen(#x ".in", "r", stdin), freopen(#x ".out", "w", stdout)
#define debug(fmt, ...) \
fprintf(stderr, "[%s:%d] " fmt "\n", __FILE__, __LINE__, ##__VA_ARGS__)
template <class _Tp>
inline _Tp& read(_Tp &x) {
bool sign = false;
char ch = getchar();
long double tmp = 1;
for (; !isdigit(ch); ch = getchar())
sign |= (ch == '-');
for (x = 0; isdigit(ch); ch = getchar())
x = x * 10 + (ch ^ 48);
if (ch == '.')
for (ch = getchar(); isdigit(ch); ch = getchar())
tmp /= 10.0, x += tmp * (ch ^ 48);
return sign ? (x = -x) : x;
}
template <class _Tp>
inline void write(_Tp x) {
if (x < 0)
putchar('-'), x = -x;
if (x > 9)
write(x / 10);
putchar((x % 10) ^ 48);
}
using comp = complex<long double>;
const int MAXN = 1 << 20;
const int INFL = 0x3f3f3f3f3f3f3f3f;
const long double MPI = acos(-1);
// 对应方法1
namespace solve1 {
int mod;
void init(int mod) {
solve1::mod = mod;
}
int qmul(int x, int y, int p = mod) {
int ret = 0;
for (; y; y >>= 1, x = (x + x) % p)
if (y & 1)
ret = (ret + x) % p;
return ret;
}
int qpow(int x, int y, int p = mod) {
int ret = 1;
for (; y; y >>= 1, x = qmul(x, x, p))
if (y & 1)
ret = qmul(ret, x, p);
return ret;
}
void change(int *f, int len) {
static int rev[MAXN];
for (int i = rev[0] = 0; i < len; ++i) {
rev[i] = rev[i >> 1] >> 1;
if (i & 1)
rev[i] |= len >> 1;
}
for (int i = 0; i < len; ++i)
if (i < rev[i])
swap(f[i], f[rev[i]]);
}
void ntt(int *f, int len, int on) {
change(f, len);
for (int h = 2; h <= len; h <<= 1) {
int gn = qpow(3, (mod - 1) / h);
for (int j = 0; j < len; j += h) {
int g = 1;
for (int k = j; k < j + h / 2; ++k) {
int u = f[k], t = g * f[k + h / 2] % mod;
f[k] = (u + t + mod) % mod;
f[k + h / 2] = (u - t + mod) % mod;
g = g * gn % mod;
}
}
}
if (on == -1) {
reverse(f + 1, f + len);
int inv = qpow(len, mod - 2);
for (int i = 0; i < len; ++i)
f[i] = f[i] * inv % mod;
}
}
const int MODS[3] = {998244353, 1004535809, 469762049};
const int INV1 = qpow(MODS[0], MODS[1] - 2, MODS[1]);
const int INV2 = qpow(MODS[0] * MODS[1] % MODS[2], MODS[2] - 2, MODS[2]);
int n, m, p, a[3][MAXN], b[3][MAXN], ans[3][MAXN];
int crt(int a1, int a2, int a3) {
int t = (a2 - a1 + MODS[1]) % MODS[1] * INV1 % MODS[1] * MODS[0] + a1;
return ((a3 - t % MODS[2] + MODS[2]) % MODS[2] * INV2 % MODS[2] * (MODS[0] * MODS[1] % p) % p + t) % p;
}
signed main() {
read(n), read(m), read(p);
for (int i = 0, x; i <= n; ++i)
read(x) %= p, a[0][i] = a[1][i] = a[2][i] = x;
for (int i = 0, x; i <= m; ++i)
read(x) %= p, b[0][i] = b[1][i] = b[2][i] = x;
int len = 1;
while (len < n + m + 1)
len <<= 1;
for (int k = 0; k < 3; ++k) {
init(MODS[k]);
ntt(a[k], len, 1), ntt(b[k], len, 1);
for (int i = 0; i < len; ++i)
ans[k][i] = a[k][i] * b[k][i] % mod;
ntt(ans[k], len, -1);
}
for (int i = 0; i <= n + m; ++i)
write(crt(ans[0][i], ans[1][i], ans[2][i])), putchar(" \n"[i == n + m]);
return 0;
}
}
// 对应方法2
namespace solve2 {
void change(comp *f, int len) {
static int rev[MAXN];
for (int i = rev[0] = 0; i < len; ++i) {
rev[i] = rev[i >> 1] >> 1;
if (i & 1)
rev[i] |= len >> 1;
}
for (int i = 0; i < len; ++i)
if (i < rev[i])
swap(f[i], f[rev[i]]);
}
void fft(comp *f, int len, int on) {
change(f, len);
for (int h = 2; h <= len; h <<= 1) {
comp wn(cos(2 * MPI / h), sin(2 * MPI / h));
for (int j = 0; j < len; j += h) {
comp w(1, 0);
for (int k = j; k < j + h / 2; ++k) {
comp u = f[k], t = w * f[k + h / 2];
f[k] = u + t;
f[k + h / 2] = u - t;
w *= wn;
}
}
}
if (on == -1) {
reverse(f + 1, f + len);
for (int i = 0; i < len; ++i)
f[i] /= len;
}
}
const int BLOC = 1 << 15;
int n, m, p, a[MAXN], b[MAXN], ans[MAXN];
comp p1[MAXN], p2[MAXN], q[MAXN], t1[MAXN], t2[MAXN];
signed main() {
read(n), read(m), read(p);
for (int i = 0; i <= n; ++i)
read(a[i]) %= p;
for (int i = 0; i <= m; ++i)
read(b[i]) %= p;
int len = 1;
while (len < n + m + 1)
len <<= 1;
for (int i = 0; i < len; ++i)
p1[i] = comp(a[i] / BLOC, a[i] % BLOC);
for (int i = 0; i < len; ++i)
p2[i] = comp(a[i] / BLOC, -a[i] % BLOC);
for (int i = 0; i < len; ++i)
q[i] = comp(b[i] / BLOC, b[i] % BLOC);
fft(p1, len, 1), fft(p2, len, 1), fft(q, len, 1);
for (int i = 0; i < len; ++i)
t1[i] = p1[i] * q[i];
for (int i = 0; i < len; ++i)
t2[i] = p2[i] * q[i];
fft(t1, len, -1), fft(t2, len, -1);
for (int i = 0; i < len; ++i) {
int a1a2 = ((int)((t1[i].real() + t2[i].real()) / 2 + 0.5)) % p;
int a1b2 = ((int)((t2[i].imag() + t1[i].imag()) / 2 + 0.5)) % p;
int a2b1 = ((int)((t1[i].imag() - t2[i].imag()) / 2 + 0.5)) % p;
int b1b2 = ((int)((t2[i].real() - t1[i].real()) / 2 + 0.5)) % p;
ans[i] = (a1a2 * (1ll << 30) % p + (a1b2 + a2b1) * (1ll << 15) % p + b1b2) % p;
}
for (int i = 0; i <= n + m; ++i)
write(ans[i]), putchar(" \n"[i == n + m]);
return 0;
}
}
signed main() {
// return solve1::main();
return solve2::main();
}