// Reads n and k, then k allowed digits, and prints (mod 998244353) the number
// of pairs of length-(n/2) strings over the allowed digits that have equal
// digit sums -- presumably the classic "lucky tickets" count for an n-digit
// ticket (TODO confirm against the problem statement).
//
// Approach: let P(x) = sum of x^d over the allowed digits d. The coefficient
// c_s of x^s in P(x)^(n/2) is the number of half-strings with digit sum s, so
// the answer is sum_s c_s^2. P^(n/2) is obtained by a forward NTT, pointwise
// exponentiation, and an inverse NTT.
#include <algorithm>
#include <cctype>
#include <cstddef>
#include <cstdio>
#include <vector>

namespace {

using i64 = long long;

constexpr i64 MOD = 998244353;  // 119 * 2^23 + 1: supports NTT lengths up to 2^23
constexpr i64 G0 = 3;           // primitive root modulo 998244353

// Computes base^exp mod MOD by binary exponentiation (exp >= 0; 0^0 == 1).
i64 qpow(i64 base, i64 exp) {
  i64 result = 1;
  base %= MOD;
  for (; exp > 0; exp >>= 1, base = base * base % MOD) {
    if (exp & 1) result = result * base % MOD;
  }
  return result;
}

// Reads the next (optionally negative) decimal integer from stdin, skipping any
// non-digit characters before it. A '-' counts as a sign only when it directly
// precedes the digits. Returns 0 if EOF is reached before any digit, instead of
// spinning forever.
i64 read_int() {
  int ch = std::getchar();  // int, not char: must be able to hold EOF
  bool negative = false;
  while (ch != EOF && !std::isdigit(ch)) {
    negative = (ch == '-');
    ch = std::getchar();
  }
  i64 value = 0;
  while (ch != EOF && std::isdigit(ch)) {
    value = value * 10 + (ch - '0');
    ch = std::getchar();
  }
  return negative ? -value : value;
}

// In-place iterative number-theoretic transform over Z/MOD.
// a.size() must be a power of two no larger than 2^23.
// With invert == true this computes the inverse transform, including the
// division by the length.
void ntt(std::vector<i64>& a, bool invert) {
  const std::size_t len = a.size();

  // Bit-reversal permutation so the butterflies can run bottom-up.
  for (std::size_t i = 1, j = 0; i < len; ++i) {
    std::size_t bit = len >> 1;
    for (; j & bit; bit >>= 1) j ^= bit;
    j ^= bit;
    if (i < j) std::swap(a[i], a[j]);
  }

  for (std::size_t h = 2; h <= len; h <<= 1) {
    const i64 wn = qpow(G0, (MOD - 1) / static_cast<i64>(h));  // primitive h-th root of unity
    const std::size_t half = h / 2;
    for (std::size_t start = 0; start < len; start += h) {
      i64 w = 1;
      for (std::size_t pos = start; pos < start + half; ++pos) {
        const i64 u = a[pos];
        const i64 t = w * a[pos + half] % MOD;
        a[pos] = (u + t) % MOD;
        a[pos + half] = (u - t + MOD) % MOD;
        w = w * wn % MOD;
      }
    }
  }

  if (invert) {
    // The inverse transform is the forward transform with indices 1..len-1
    // reversed, followed by scaling by 1/len.
    std::reverse(a.begin() + 1, a.end());
    const i64 inv_len = qpow(static_cast<i64>(len), MOD - 2);
    for (i64& x : a) x = x * inv_len % MOD;
  }
}

}  // namespace

int main() {
  const i64 half = read_int() / 2;  // digits per half of the ticket
  const i64 k = read_int();         // number of allowed digits

  std::vector<i64> digits;
  i64 max_digit = 0;
  for (i64 i = 0; i < k; ++i) {
    const i64 d = read_int();
    if (d < 0) continue;  // not a valid digit; would index out of range
    digits.push_back(d);
    max_digit = std::max(max_digit, d);
  }

  // P^half has degree max_digit * half, so the cyclic transform length must
  // exceed it to avoid wrap-around. The max(half, 1) also keeps every digit
  // index in range when half == 0.
  std::size_t len = 1;
  while (static_cast<i64>(len) <= max_digit * std::max<i64>(half, 1)) len <<= 1;

  std::vector<i64> poly(len, 0);
  for (const i64 d : digits) poly[static_cast<std::size_t>(d)] = 1;

  ntt(poly, false);
  for (i64& v : poly) v = qpow(v, half);
  ntt(poly, true);

  // Each coefficient c_s counts half-strings with digit sum s; pairing two
  // halves with the same sum gives c_s^2 combinations.
  i64 ans = 0;
  for (const i64 c : poly) ans = (ans + c * c) % MOD;

  std::printf("%lld\n", ans);
  return 0;
}